#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/cuda/LaunchUtils.h>
#include <c10/macros/Macros.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

namespace at { namespace native {

// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif

constexpr unsigned MAX_GRID_SIZE = 65535u;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}
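
// Illustrative sketch (not from the original source): a typical launch shape built
// from getNumThreads for the per-plane kernels below, where `input` is assumed to be
// a (N, C, HW) packed accessor. On CUDA builds, getNumThreads(100) == 128 and
// getNumThreads(10000) saturates at MAX_BLOCK_SIZE (512).
//
//   int tf = getNumThreads(input.size(2));
//   dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE / tf));
//   dim3 blocks(input.size(1));   // one block per plane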

// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
  return 31 - __clz(val);
}

template <typename scalar_t, typename accscalar_t>
struct Float2 {
  accscalar_t v1, v2;
  __device__ Float2() {}
  __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {}
  __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}
  __device__ Float2& operator+=(const Float2& a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
  __device__ friend Float2 operator+(Float2 a, const Float2& b) {
    a += b;
    return a;
  }
};

template <typename scalar_t, typename accscalar_t, typename PTA>
struct GradOp {
  __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
    : mean(m), input(i), grad_output(g) {}
  __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
    accscalar_t g = grad_output[batch][plane][n];
    accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
    return Float2<scalar_t, accscalar_t>(g, g * c);
  }
  const accscalar_t mean;
  const PTA& input;
  const PTA& grad_output;
};

template <typename acc_t>
struct SumReduceOp {
    __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }

    __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
        return WARP_SHFL_DOWN(data, offset);
    }
};

template <typename scalar_t, typename accscalar_t>
struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
    using acc_t = Float2<scalar_t, accscalar_t>;

    __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }

    __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
        return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)};
    }
};

// Sum across (batch, x/y/z) applying Op() pointwise
// this works by first having each thread sum its part
// of the data. Then there is a double-shuffling reduction.
// First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its
// data to the "warp leader", who writes its value into shared memory.
// Then a single warp reads the remaining (at most C10_WARP_SIZE) items
// and reduces them using another warpSum.
// The implicit assumption is that there are no more
// than C10_WARP_SIZE**2 threads.
template<typename scalar_t, typename Op, typename PTA>
__device__ scalar_t reduce(Op op, PTA tensor, int plane) {
  // first the reductions each thread does separately
  scalar_t sum = static_cast<scalar_t>(0);
  for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }
  __shared__ scalar_t shared[C10_WARP_SIZE];
  SumReduceOp<scalar_t> reduce_op;
  sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
      shared[0] = sum;
  }
  __syncthreads();
  // Everyone picks it up, should be broadcast into the whole grad_input
  return shared[0];
}
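
// Usage sketch (mirrors the calls further down in this header): `reduce` is invoked
// once per plane with a functor such as GradOp, and every thread of the 2D block
// receives the same pair of sums back:
//
//   GradOp<scalar_t, accscalar_t, PTA> g(mean, input, grad_output);
//   auto res = reduce<Float2<scalar_t, accscalar_t>>(g, grad_output, plane);
//   // res.v1 == sum(grad_output), res.v2 == sum((input - mean) * grad_output)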

constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency
constexpr int ELEMENTS_PER_THREAD = 16;
constexpr int OPTIMAL_TILE_W = 32;
constexpr int MAX_H_BLOCK = 128;

__host__ void flexible_launch_configs(
      const int reduction,
      const int stride,
      dim3 &block,
      dim3 &grid,
      const bool coop_flag = false) {
  int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
  int block_y = std::min(lastPow2(at::ceil_div(reduction, ELEMENTS_PER_THREAD)),
                         MAX_BLOCK_SIZE / block_x);
  if (block_x * block_y != MAX_BLOCK_SIZE) {
    block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
  }

  int grid_x = at::ceil_div(stride, block_x);
  int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  if (coop_flag) {
    // it's not worth having a grid reduction if the reduction dimension is not big enough
    grid_y = grid_y < 8 ? 1 : grid_y;
  }

  block.x = block_x;
  block.y = block_y;
  block.z = 1;
  grid.x = grid_x;
  grid.y = grid_y;
  grid.z = 1;
}
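
// Worked example (illustrative, assuming the CUDA MAX_BLOCK_SIZE of 512 and
// ELEMENTS_PER_THREAD == 16): for a channels-last reduction with stride == 64 and
// reduction == 8192,
//   block_x = min(lastPow2(64), 32)                 = 32
//   block_y = min(lastPow2(8192 / 16), 512 / 32)    = 16   (block_x * block_y == 512)
//   grid_x  = ceil_div(64, 32)                      = 2
//   grid_y  = min(ceil_div(8192, 16 * 16), 128)     = 32
// With coop_flag == true, grid_y stays 32 because it is >= 8.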

template<typename T, typename C>
__device__ __forceinline__ void welford_merge_element(C& count,
                                                      T& mean,
                                                      T& m2n,
                                                      const C& count_new,
                                                      const T& mean_new,
                                                      const T& m2n_new) {
      T factor = T(1.0) / ::max(1, (count + count_new));
      T delta0 = mean - mean_new;
      mean = (mean_new * count_new + mean * count) * factor;
      m2n += m2n_new + delta0 * delta0 * count_new * count * factor;
      count += count_new;
}
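
// Reference note (added for readability, not from the original source): the merge
// above is the parallel Welford/Chan update,
//   n     = n_a + n_b
//   delta = mean_a - mean_b
//   mean  = (n_a * mean_a + n_b * mean_b) / n
//   M2    = M2_a + M2_b + delta^2 * n_a * n_b / n
// The ::max(1, ...) guard only avoids dividing by zero when both counts are zero.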

// merge mean/m2n among threadIdx.y within block
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
                                                             T& mean,
                                                             T& m2n,
                                                             C* shmem_count,
                                                             T* shmem_mean,
                                                             T* shmem_m2n) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;

#pragma unroll
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    if (threadIdx.y < offset*2) {
      shmem_mean[address_base] = mean;
      shmem_m2n[address_base] = m2n;
      shmem_count[address_base] = count;
    }
    __syncthreads();
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;
      // read shared memory back to register for reduction
      auto count_new = shmem_count[address];
      auto mean_new = shmem_mean[address];
      auto m2n_new = shmem_m2n[address];

      welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new);
    }
  }
}
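
// Illustrative note (assumption about intended use, consistent with the
// channels-last kernel below): with blockDim == (32, 8) the loop runs for
// offset = 4, 2, 1, and after the last iteration every thread with
// threadIdx.y == 0 holds the Welford statistics merged over all 8 rows of its
// shared-memory column.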

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
__global__ void batch_norm_transform_input_kernel(
    const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
    stat_accscalar_t epsilon) {

  index_t plane = blockIdx.x;

  if (plane >= input.size(1)) {
    return;
  }

  stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1);
  stat_accscalar_t beta = bias.size(0) > 0 ? static_cast<stat_accscalar_t>(bias[plane]) : static_cast<stat_accscalar_t>(0);
  stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);
  stat_accscalar_t invstd;
  if (train) {
    invstd = var_or_invstd[plane];
  } else {
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(var_or_invstd[plane]) + epsilon);
  }

  index_t bs = input.size(0);
  index_t fs = input.size(2);

  index_t bstep  = blockDim.y * gridDim.y;
  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
    auto o = output[batch][plane];
    auto i = input[batch][plane];
    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
      o[feature] = static_cast<input_scalar_t>(gamma * (i[feature] - mean) * invstd + beta);
    }
  }
}
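
// The per-element transform applied above, written out (descriptive note added for
// readability, not from the original source):
//   training  (train == true):  y = gamma * (x - batch_mean) * saved_invstd + beta
//   inference (train == false): y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta
// where gamma/beta fall back to 1/0 when weight/bias are undefined (size 0).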

struct InvStd {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    T invstd = 0;
    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
      invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
    }
    return invstd;
  }
};

struct Var {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    return var;
  }
};

template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_collect_statistics_kernel(
    const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    const stat_accscalar_t epsilon,
    const stat_accscalar_t momentum,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {

  __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE];

  int plane = blockIdx.x;
  int N = input.size(0) * input.size(2);
  int tid = threadIdx.x + threadIdx.y * blockDim.x;

  // Compute the mean and variance across (batch, x/y/z)
  // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block)
  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
  // and the parallel algorithm on the same page.
  // We use two shuffles to reduce across the entire block.
  // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
  stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE];

  // first the reductions each thread does separately
  stat_accscalar_t avg = 0;
  stat_accscalar_t var_n = 0;
  int n = 0;
  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
      stat_accscalar_t v = input[batch][plane][x];
      stat_accscalar_t d1 = v - avg;
      n++;
      avg += d1 / n;
      var_n += d1 * (v - avg);
    }
  }

  // first warpSum to get one value per thread to
  // one value per warp
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // this writes each warp's item into shared memory
  // there are at most C10_WARP_SIZE items left because
  // there are at most C10_WARP_SIZE**2 threads at the beginning
  __syncthreads();
  if (tid % C10_WARP_SIZE == 0) {
    shared_n[tid / C10_WARP_SIZE] = n;
    shared_avg_var[tid / C10_WARP_SIZE * 2] = avg;
    shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n;
  }
  __syncthreads();
  // now have a second warpSum to reduce the intermediate values
  // from shared memory to a single number. The very first
  // thread writes it to shared memory.

  if (tid < C10_WARP_SIZE) {
    n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0);
    avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
    var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
  }
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // Save the mean, variance, and moving averages
  if (tid == 0) {
    if (save_mean.data() != NULL) {
      save_mean[plane] = avg;
    }
    if (save_transformed_var.data() != NULL) {
      save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
    }
  }

}
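
// Launch sketch (mirrors batch_norm_stats_cuda_template further down, here with
// InvStd as the VarTransform): one block per plane, shaped so that
// blockDim.x * blockDim.y <= MAX_BLOCK_SIZE, which keeps the thread count within
// the C10_WARP_SIZE**2 assumption stated above.
//
//   dim3 blocks(input.size(1));
//   int tf = getNumThreads(input.size(2));
//   dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE / tf));
//   batch_norm_collect_statistics_kernel<InvStd, scalar_t, scalar_t, accscalar_t, index_t>
//       <<<blocks, threads, 0, stream>>>(input, epsilon, 0.0, save_mean, save_invstd);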

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_kernel(
    const GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> running_mean,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> running_var,
    const GenericPackedTensorAccessor<const stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_mean,
    const GenericPackedTensorAccessor<const stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd,
    bool train,
    stat_accscalar_t epsilon) {

  index_t plane = blockIdx.x;
  index_t N = grad_output.size(0) * grad_output.size(2);

  stat_accscalar_t mean, invstd;
  if (train) {
    mean = save_mean[plane];
    invstd = save_invstd[plane];
  } else {
    mean = static_cast<stat_accscalar_t>(running_mean[plane]);
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(running_var[plane]) + epsilon);
  }

  stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
  stat_accscalar_t norm = stat_accscalar_t(1) / N;

  // Compute two values across (batch, x/y/z) in one pass:
  // 1. Sum(grad_output)
  // 2. DotProduct(input - mean, grad_output)
  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);

  stat_accscalar_t grad_output_sum = res.v1;
  stat_accscalar_t dot_p = res.v2;

  stat_accscalar_t grad_mean = grad_output_sum * norm;
  stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd;
  stat_accscalar_t grad_scale = invstd * weight_val;

  if (grad_input.data() != NULL) {
    for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
      for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
        input_scalar_t go = grad_output[batch][plane][x];
        if (train) {
          stat_accscalar_t inp = input[batch][plane][x];
          stat_accscalar_t proj = (inp - mean) * proj_scale;
          grad_input[batch][plane][x] = static_cast<input_scalar_t>((go - proj - grad_mean) * grad_scale);
        } else {
          grad_input[batch][plane][x] = static_cast<input_scalar_t>(go * grad_scale);
        }
      }
    }
  }

  if (grad_weight.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(dot_p * invstd);
    }
  }

  if (grad_bias.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(grad_output_sum);
    }
  }
}
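
// For reference (descriptive note, added for readability): with
//   sum_dy = sum(grad_output)   and   dot_p = sum((x - mean) * grad_output)
// reduced over (batch, x/y/z), the gradients written above are
//   grad_input  = (dy - sum_dy / N - (x - mean) * invstd^2 * dot_p / N) * invstd * gamma   (training)
//   grad_input  = dy * invstd * gamma                                                      (eval)
//   grad_weight = dot_p * invstd
//   grad_bias   = sum_dy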

template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void batch_norm_reduce_statistics_kernel(
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
    const accscalar_t epsilon,
    const accscalar_t momentum,
    const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> counts) {

  int feature_size = vec_mean.size(1);
  int world_size = vec_mean.size(0);

  int bid = blockIdx.x;
  int tid = threadIdx.x;

  // first the reductions each thread does separately
  for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
    accscalar_t avg = 0;
    accscalar_t var_n = 0;
    index_t n = 0;
    for (int j = 0; j < world_size; j++) {
      scalar_t count = counts[j];
      accscalar_t m = vec_mean[j][i];
      accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
      v = (v * v - epsilon) * count;
      accscalar_t factor = 1.0 / (n + count);
      var_n += v + (avg - m) * (avg - m) * n * count * factor;
      avg = n * factor * avg + count * factor * m;
      n += count;
    }
    mean[i] = avg;
    invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);
    if (running_mean.data() != NULL) {
      running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg);
    }
    accscalar_t unbiasedVar = var_n / (n - 1);
    if (running_var.data() != NULL) {
      running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar);
    }
  }

}
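
// Note (added for readability): each rank j contributes (mean_j, invstd_j, count_j).
// The per-rank variance is recovered as var_j = 1 / invstd_j^2 - epsilon, the running
// Welford merge above combines the ranks, and the result is written back as
// invstd = 1 / sqrt(var_n / n + epsilon). Running stats, when present, are updated
// with the unbiased variance var_n / (n - 1).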

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_reduce_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {

  index_t plane = blockIdx.x;

  stat_accscalar_t r_mean = mean[plane];
  stat_accscalar_t factor = invstd[plane];

  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);

  if (threadIdx.x == 0) {
    if (grad_weight.size(0) > 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(res.v2 * factor);
    }
    if (grad_bias.size(0) > 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(res.v1);
    }
    if (sum_dy.size(0) > 0) {
      sum_dy[plane] = static_cast<stat_accscalar_t>(res.v1);
    }
    if (sum_dy_xmu.size(0) > 0) {
      sum_dy_xmu[plane] = static_cast<stat_accscalar_t>(res.v2);
    }
  }
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const stat_accscalar_t norm_fct) {
  index_t plane = blockIdx.x;

  if (plane >= input.size(1)) {
    return;
  }

  stat_accscalar_t m_c = mean[plane];
  stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct;
  stat_accscalar_t factor_1_c = invstd[plane];
  stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
  factor_2_c *= factor_1_c;
  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct;

  index_t bs = input.size(0);
  index_t fs = input.size(2);

  index_t bstep  = blockDim.y * gridDim.y;
  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
    auto g_i = grad_input[batch][plane];
    auto g_o = grad_output[batch][plane];
    auto i = input[batch][plane];
    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
      g_i[feature] = static_cast<input_scalar_t>((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c);
    }
  }
}
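
// Written out (descriptive note, added for readability): with norm_fct = 1 / total_count,
// the element-wise gradient computed above is
//   grad_input = (dy - sum_dy * norm_fct
//                    - (x - mean) * invstd^2 * sum_dy_xmu * norm_fct) * invstd * gamma
// i.e. the same training-mode formula as batch_norm_backward_kernel, but with the
// reductions (sum_dy, sum_dy_xmu) supplied by the caller instead of computed here.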

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_elemt_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const int* __restrict__ numel, const int world_size) {
  int64_t total_numel = 0;
  for (int i = 0; i < world_size; i++) {
    total_numel += numel[i];
  }

  const stat_accscalar_t norm_fct =
      static_cast<stat_accscalar_t>(1) / static_cast<stat_accscalar_t>(total_numel);
  batch_norm_backward_elemt_kernel_impl(
      input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_elemt_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const stat_accscalar_t norm_fct) {
  batch_norm_backward_elemt_kernel_impl(
      input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
}

template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
    const Tensor& t, c10::string_view var_name) {
  constexpr auto expect_type = c10::CppTypeToScalarType<typename std::remove_const<scalar_t>::type>::value;
  const auto actual_type = t.scalar_type();
  TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
              " to have type ", expect_type, " but got ", actual_type);
  return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
}

template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
    const Tensor& t, c10::string_view var_name) {
  if (!t.defined()) {
    const std::array<index_t, dim> zeros{{0}};
    return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
  }
  return get_packed_accessor<scalar_t, dim, PtrTraits, index_t>(t, var_name);
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
                                                                     const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
                                                                     bool train, double epsilon, std::array<bool,3> grad_input_mask) {

  using accscalar_t = at::acc_type<stat_scalar_t, true>;
  Tensor grad_input_;
  Tensor grad_input_reshaped;
  Tensor grad_weight_;
  Tensor grad_bias_;
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());

  if (grad_input_mask[0]) {
    grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    grad_input_reshaped = grad_input_.view(input_reshaped.sizes());
  }
  if (grad_input_mask[1]) {
    grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (grad_input_mask[2]) {
    grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  auto input = get_packed_accessor<
      const input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_output = get_packed_accessor<
      const input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto grad_input = packed_accessor_or_dummy<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  auto weight = packed_accessor_or_dummy<
      const stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto grad_weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
  auto grad_bias = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
  auto running_mean = packed_accessor_or_dummy<
      const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean");
  auto running_var = packed_accessor_or_dummy<
      const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var");
  auto save_mean = packed_accessor_or_dummy<
      const accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean");
  auto save_invstd = packed_accessor_or_dummy<
      const accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd");

  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input.size(1));
  int tf = getNumThreads(input.size(2));
  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));

  batch_norm_backward_kernel<input_scalar_t, stat_scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
    (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var,
     save_mean, save_invstd, train, epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
}

template<typename scalar_t, typename index_t, typename VarTransform>
void batch_norm_stats_cuda_template(
    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {

  using accscalar_t = at::acc_type<scalar_t, true>;
  int64_t n_input = input_.size(1);
  Tensor dummy_mean_;
  Tensor dummy_var_;
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions

  resize_output(out_mean, {n_input});
  resize_output(out_invstd, {n_input});
  auto input = get_packed_accessor<
      const scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
                        out_invstd.sizes()[0]);
  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
                        out_mean.sizes()[0]);

  auto mean = packed_accessor_or_dummy<
      accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean");
  auto invstd = packed_accessor_or_dummy<
      accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  dim3 blocks(input.size(1));
  int tf = getNumThreads(input.size(2));
  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
  batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
    (input, epsilon, 0.0, mean, invstd);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_,
                                    const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  int64_t n_input = input_.size(1);
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});

  auto input = get_packed_accessor<
      const input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
  auto output = get_packed_accessor<
      input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
  auto weight = packed_accessor_or_dummy<
      const stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
  auto bias = packed_accessor_or_dummy<
      const stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  // NOTE: We use transform_input_kernel in training mode, which ignores epsilon
  const double dummy_epsilon = 1e-5;

  // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  // and good occupancy. Quite likely, we could go with even more blocks than 1024.
  // The various planes are independent, so we use blocks for them.
  int tf = std::max<int>(getNumThreads(input.size(2)/4),
                         std::min<int>(getNumThreads(input.size(2)), 64));
  int tb = std::max<int>(64/tf, 1);
  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                                  (input.size(0)+tb-1)/tb)));
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);
  batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
    (input, output, mean, invstd, weight, bias, dummy_epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
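
// Worked example of the launch heuristic above (illustrative numbers, assuming the
// CUDA MAX_BLOCK_SIZE of 512): for input_reshaped of shape (N=32, C=64, HW=4096),
//   tf = max(getNumThreads(4096 / 4), min(getNumThreads(4096), 64))        = 512
//   tb = max(64 / tf, 1)                                                   = 1
//   blocks_trans  = dim3(64, max(1, min(256*1024 / 64, ceil(32 / tb))))    = (64, 32)
//   threads_trans = dim3(tf, tb)                                           = (512, 1)
// so each plane gets a column of 32 blocks of 512x1 threads.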

template<typename scalar_t, typename accscalar_t, typename index_t>
std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
                                                                 const Tensor& running_mean_, const Tensor& running_var_,
                                                                 double momentum, double epsilon, const Tensor& counts_) {

  Tensor save_mean_;
  Tensor save_invstd_;

  auto features = mean_.size(1);
  auto input_options = mean_.options();
  if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) {
    input_options = input_options.dtype(ScalarType::Float);
  }
  save_mean_ = at::empty({features}, input_options);
  save_invstd_ = at::empty({features}, input_options);

  auto mean = packed_accessor_or_dummy<
      accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd");
  auto running_mean = packed_accessor_or_dummy<
      scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean");
  auto running_var = packed_accessor_or_dummy<
      scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_var");
  auto counts = packed_accessor_or_dummy<
      scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts");

  auto save_mean = get_packed_accessor<
      accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean");
  auto save_invstd = get_packed_accessor<
      accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  int block = getNumThreads(features);
  int grid = std::max<int>(1, features/block);
  batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
      (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(save_mean_, save_invstd_);
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
                                                                                    const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_,
                                                                                    const bool input_g, const bool weight_g, const bool bias_g) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  int64_t n_input = input_.size(1);
  Tensor sum_dy_;
  Tensor sum_dy_xmu_;
  Tensor grad_weight_;
  Tensor grad_bias_;
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());

  if (input_g) {
    sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (weight_g) {
    grad_weight_ = at::empty({n_input}, weight_.options());
  }
  if (bias_g) {
    grad_bias_ = at::empty({n_input}, weight_.options());
  }

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto grad_weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
  auto grad_bias = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  auto sum_dy = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  auto sum_dy_xmu = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");

  auto batch_size = input_reshaped.size(0);
  auto feature_size = input_reshaped.size(2);
  auto stream = at::cuda::getCurrentCUDAStream();

  int warp_size = at::cuda::warp_size();
  int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size);
  // We want block_x to be at least a warp width
  int block_x = std::min<int>(std::max<int>(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y);
  const dim3 block(block_x, block_y);
  const dim3 grid(n_input);

  batch_norm_backward_reduce_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<grid, block, 0, stream>>>
    (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_);
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
                                               const Tensor& mean_, const Tensor& invstd_,
                                               const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  int64_t n_input = input_.size(1);
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
  auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  auto weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto sum_dy = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  auto sum_dy_xmu = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");

  auto stream = at::cuda::getCurrentCUDAStream();

  // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  // and good occupancy. Quite likely, we could go with even more blocks than 1024.
  // The various planes are independent, so we use blocks for them.
  int tf = std::max<int>(getNumThreads(input.size(2)/4),
                         std::min<int>(getNumThreads(input.size(2)), 64));
  int tb = std::max<int>(64/tf, 1);
  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                                  (input.size(0)+tb-1)/tb)));
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);
  auto reduction_size = input_.numel() / n_input;
  auto norm_fct = static_cast<stat_accscalar_t>(1.0 / reduction_size);
  batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t>
      <<<blocks_trans, threads_trans, 0, stream>>>
      (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return grad_input_reshaped.view(input_.sizes());
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
                                               const Tensor& mean_, const Tensor& invstd_,
                                               const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  int64_t n_input = input_.size(1);
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
  auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  auto weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto sum_dy = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  auto sum_dy_xmu = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");

  auto stream = at::cuda::getCurrentCUDAStream();

  // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  // and good occupancy. Quite likely, we could go with even more blocks than 1024.
  // The various planes are independent, so we use blocks for them.
  int tf = std::max<int>(getNumThreads(input.size(2)/4),
                         std::min<int>(getNumThreads(input.size(2)), 64));
  int tb = std::max<int>(64/tf, 1);
  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                                  (input.size(0)+tb-1)/tb)));
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);
  batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
    (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.const_data_ptr<int>(), count.numel());
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return grad_input_reshaped.view(input_.sizes());
}
935 
936 // welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
937 // original apex name: welford_kernel_c_last
938 template
939    <typename VarTransform,
940     typename scalar_t,
941     typename accscalar_t,
942     int PARALLEL_LOADS>
943 __global__ void
batch_norm_collect_statistics_channels_last_kernel(const scalar_t * __restrict__ input,accscalar_t * __restrict__ out_mean,accscalar_t * __restrict__ out_invstd,volatile accscalar_t * staging_data,int * semaphores,const int reduction_size,const int stride,accscalar_t epsilon)944 batch_norm_collect_statistics_channels_last_kernel(
945       const scalar_t* __restrict__ input,
946       accscalar_t* __restrict__ out_mean,
947       accscalar_t* __restrict__ out_invstd,
948       volatile accscalar_t* staging_data,
949       int* semaphores,
950       const int reduction_size,
951       const int stride,
952       accscalar_t epsilon) {
953   // hide latency with concurrency
954   accscalar_t x_mean[PARALLEL_LOADS];
955   accscalar_t m_2_n[PARALLEL_LOADS];
956   int count[PARALLEL_LOADS];
957 
958 #pragma unroll
959   for (int i = 0; i < PARALLEL_LOADS; i++) {
960     x_mean[i] = accscalar_t(0);
961     m_2_n[i] = accscalar_t(0);
962     count[i] = 0;
963   }
964   // tensor dimension (m,c)
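  // For a channels-last (NHWC) activation the host passes stride = C and reduction_size = N*H*W,
  // so the flat offset m * stride + c addresses batch/spatial position m of channel c; threads
  // along x cover channels and threads/blocks along y cover the m dimension.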
965 
966   // loop along m dimension
967   int inner_loop_stride = blockDim.y * gridDim.y;
968 
969   // offset along m dimension
970   int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
971   int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
972 
973   int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
974   int address_base = m_offset * stride + c_offset;
975   int address_increment = inner_loop_stride * stride;
976 
977   for (int i = 0; i < loop_count; i++) {
978     accscalar_t x_math[PARALLEL_LOADS];
979     accscalar_t x_count_inv[PARALLEL_LOADS];
980     accscalar_t is_valid[PARALLEL_LOADS];
981 
982     // load multiple data in
983 #pragma unroll
984     for (int j = 0; j < PARALLEL_LOADS; j++) {
985       if (c_offset < stride && m_offset < reduction_size) {
986         x_math[j] = input[address_base];
987         count[j]++;
988         x_count_inv[j] = accscalar_t(1) / count[j];
989         is_valid[j] = accscalar_t(1);
990       } else {
991         x_math[j] = accscalar_t(0);
992         x_count_inv[j] = accscalar_t(0);
993         is_valid[j] = accscalar_t(0);
994       }
995       m_offset += inner_loop_stride;
996       address_base += address_increment;
997     }
998 
999     // calculate mean/m2n with welford
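    // Each lane applies Welford's online update: with running count n (x_count_inv = 1/n),
    //   delta0 = x - mean;  mean += delta0 / n;  delta1 = x - mean;  M2 += delta0 * delta1,
    // so mean holds the running average and M2 / n the biased variance. Out-of-range lanes have
    // x_count_inv == 0 and is_valid == 0, which leaves mean and M2 unchanged.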
1000 #pragma unroll
1001     for (int j = 0; j < PARALLEL_LOADS; j++) {
1002       accscalar_t delta0 = x_math[j] - x_mean[j];
1003       x_mean[j] += delta0 * x_count_inv[j];
1004       accscalar_t delta1 = x_math[j] - x_mean[j];
1005       m_2_n[j] += delta0 * delta1 * is_valid[j];
1006     }
1007   }
1008 
1009   // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
1010 #pragma unroll
1011   for (int j = 1; j < PARALLEL_LOADS; j++) {
1012     welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
1013   }
1014 
1015   // release x_mean / m_2_n
1016   auto mean_th = x_mean[0];
1017   auto m2_th = m_2_n[0];
1018   auto count_th = count[0];
1019 
1020   // block-wise reduction with shared memory (since reduction cannot be done within a warp)
1021   static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
1022   static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
1023   static __shared__ int shmem_count[MAX_BLOCK_SIZE];
1024 
1025   welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
1026 
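  // Multi-block reduction along grid y: each y-block stages its per-channel partials
  // (count, mean, m2n) in global memory and atomically bumps the semaphore of its x-block.
  // The last y-block to arrive re-reads all staged partials, merges them with Welford again,
  // and writes the final mean and VarTransform(m2/count, epsilon) for its channels.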
1027   if (gridDim.y > 1) {
1028     volatile accscalar_t* staging_mean = staging_data;
1029     volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
1030     volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);
1031 
1032     address_base = c_offset + blockIdx.y * stride;
1033     // write data to staging_data;
1034     if (threadIdx.y == 0 && c_offset < stride) {
1035       staging_mean[address_base] = mean_th;
1036       staging_m2n[address_base] = m2_th;
1037       staging_count[address_base] = count_th;
1038     }
1039 
1040     __threadfence();
1041     __syncthreads(); // ensure writes to the staging buffers are visible to all blocks
1042 
1043     __shared__ bool is_last_block_done;
1044     // mark block done
1045     if (threadIdx.x == 0 && threadIdx.y == 0) {
1046       int old = atomicAdd(&semaphores[blockIdx.x], 1);
1047       is_last_block_done = (old == (gridDim.y-1));
1048     }
1049 
1050     __syncthreads();
1051 
1052     // check that all data is now available in global memory
1053     if (is_last_block_done) {
1054       count_th = 0;
1055       mean_th = accscalar_t(0.0);
1056       m2_th = accscalar_t(0.0);
1057 
1058       for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
1059         address_base = c_offset + y * stride;
1060         int count_new = c_offset < stride ? staging_count[address_base] : 0;
1061         accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
1062         accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
1063 
1064         welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new);
1065       }
1066 
1067       welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
1068       if (threadIdx.y == 0 && c_offset < stride) {
1069         out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
1070         out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
1071       }
1072     }
1073   } else {
1074     if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
1075       out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
1076       out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
1077     }
1078   }
1079 }
1080 
1081 // elementwise BN kernel
1082 // original apex name: batchnorm_forward_c_last_kernel
1083 template <
1084     typename scalar_t,
1085     typename accscalar_t,
1086     typename layerscalar_t,
1087     int PARALLEL_LOADS>
1088 __global__ void batch_norm_transform_input_channels_last_kernel(
1089       const scalar_t* __restrict__ input,
1090       const scalar_t* __restrict__ z,
1091       const accscalar_t* __restrict__ mean,
1092       const accscalar_t* __restrict__ inv_std,
1093       const layerscalar_t* __restrict__ weight,
1094       const layerscalar_t* __restrict__ shift,
1095       scalar_t* __restrict__ out,
1096       const int reduction_size,
1097       const int stride,
1098       const bool fuse_relu) {
1099   // tensor dimension (m,c)
1100   // loop along m dimension
1101   int inner_loop_stride = blockDim.y * gridDim.y;
1102 
1103   // offset along m dimension
1104   int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
1105   int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
1106 
1107   if (c_offset >= stride || m_offset >= reduction_size) {
1108     return;
1109   }
1110 
1111   auto m_c = mean[c_offset];
1112   auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
1113   auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
1114   auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
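  // Per element along m this computes y = w_c * (x - m_c) * inv_std_c + s_c, adds the optional
  // residual z, and clamps negative results to zero when fuse_relu is set.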
1115 
1116   int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
1117   int address_base = m_offset * stride + c_offset;
1118   int address_increment = inner_loop_stride * stride;
1119 
1120   for (int i = 0; i < loop_count; i++) {
1121 #pragma unroll
1122     for (int j = 0; j < PARALLEL_LOADS; j++) {
1123       if (c_offset < stride && m_offset < reduction_size) {
1124         auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
1125         if (z != nullptr) {
1126           tmp += z[address_base];
1127         }
1128         out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
1129       }
1130       m_offset += inner_loop_stride;
1131       address_base += address_increment;
1132     }
1133   }
1134 }
1135 
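// Tree reduction over threadIdx.y: each step stages the current partials in shared memory and
// folds the upper half of the active rows into the lower half, so threads with threadIdx.y == 0
// end up holding the block-wide sum_dy / sum_dy_xmu for their threadIdx.x column.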
1136 template<typename T>
1137 __device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy,
1138     T& sum_dy_xmu,
1139     T* shmem_sum_dy,
1140     T* shmem_sum_dy_xmu) {
1141   // write to shared memory
1142   auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
1143 
1144 #pragma unroll
1145   for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
1146     if (threadIdx.y < offset*2) {
1147       shmem_sum_dy[address_base] = sum_dy;
1148       shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
1149     }
1150     __syncthreads();
1151     if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
1152       auto address = address_base + offset * blockDim.x;
1153 
1154       sum_dy += shmem_sum_dy[address];
1155       sum_dy_xmu += shmem_sum_dy_xmu[address];
1156     }
1157   }
1158 }
1159 
1160 // batchnorm backward reduce kernel for a channels-last tensor
1161 // original apex name: reduce_bn_c_last_kernel
1162 template <
1163     int PARALLEL_LOADS,
1164     typename scalar_t,
1165     typename accscalar_t,
1166     typename layerscalar_t>
1167 __global__ void batch_norm_backward_reduce_channels_last_kernel(
1168       const scalar_t* __restrict__ input,
1169       const scalar_t* __restrict__ grad_output,
1170       const accscalar_t* __restrict__ mean,
1171       const accscalar_t* __restrict__ inv_std,
1172       accscalar_t* __restrict__ sum_dy_o,
1173       accscalar_t* __restrict__ sum_dy_xmu_o,
1174       layerscalar_t* __restrict__ grad_weight,
1175       layerscalar_t* __restrict__ grad_bias,
1176       volatile accscalar_t* staging_data,
1177       int* semaphores,
1178       const int reduction_size,
1179       const int stride) {
1180 
1181   // hide latency with concurrency
1182   accscalar_t sum_dy[PARALLEL_LOADS];
1183   accscalar_t sum_dy_xmu[PARALLEL_LOADS];
1184 
1185 #pragma unroll
1186   for (int i = 0; i < PARALLEL_LOADS; i++) {
1187     sum_dy[i] = accscalar_t(0);
1188     sum_dy_xmu[i] = accscalar_t(0);
1189   }
1190   // tensor dimension (m,c)
1191 
1192   // loop along m dimension
1193   int inner_loop_stride = blockDim.y * gridDim.y;
1194 
1195   // offset along m dimension
1196   int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
1197   int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
1198 
1199   if (c_offset >= stride || m_offset >= reduction_size) {
1200     return;
1201   }
1202 
1203   int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
1204   int address_base = m_offset * stride + c_offset;
1205   int address_increment = inner_loop_stride * stride;
1206 
1207   auto r_mean = mean[c_offset];
1208   auto factor = inv_std[c_offset];
1209 
1210   for (int i = 0; i < loop_count; i++) {
1211     accscalar_t x_input[PARALLEL_LOADS];
1212     accscalar_t x_grad_output[PARALLEL_LOADS];
1213 
1214     // load multiple data in
1215 #pragma unroll
1216     for (int j = 0; j < PARALLEL_LOADS; j++) {
1217       if (c_offset < stride && m_offset < reduction_size) {
1218         x_input[j] = input[address_base];
1219         x_grad_output[j] = grad_output[address_base];
1220       } else {
1221         x_input[j] = accscalar_t(0);
1222         x_grad_output[j] = accscalar_t(0);
1223       }
1224       m_offset += inner_loop_stride;
1225       address_base += address_increment;
1226     }
1227 
1228     // calculate sum_dy / sum_dy_xmu
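    // sum_dy accumulates dy and sum_dy_xmu accumulates dy * (x - mean) per channel; after the
    // block/grid reduction these become grad_bias = sum(dy) and
    // grad_weight = inv_std * sum(dy * (x - mean)).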
1229 #pragma unroll
1230     for (int j = 0; j < PARALLEL_LOADS; j++) {
1231       sum_dy[j] += x_grad_output[j];
1232       sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
1233     }
1234   }
1235 
1236   // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
1237 #pragma unroll
1238   for (int j = 1; j < PARALLEL_LOADS; j++) {
1239     sum_dy[0] += sum_dy[j];
1240     sum_dy_xmu[0] += sum_dy_xmu[j];
1241   }
1242 
1243   // release array of registers
1244   auto sum_dy_th = sum_dy[0];
1245   auto sum_dy_xmu_th = sum_dy_xmu[0];
1246 
1247   // block-wise reduction with shared memory (since reduction cannot be done within a warp)
1248   static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
1249   static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
1250 
1251   merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
1252 
1253   if (gridDim.y > 1) {
1254     volatile accscalar_t* staging_sum_dy = staging_data;
1255     volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
1256 
1257     address_base = c_offset + blockIdx.y * stride;
1258     // write data to staging_data;
1259     if (threadIdx.y == 0 && c_offset < stride) {
1260       staging_sum_dy[address_base] = sum_dy_th;
1261       staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
1262     }
1263 
1264     __threadfence();
1265     __syncthreads(); // ensure writes to the staging buffers are visible to all blocks
1266 
1267     __shared__ bool is_last_block_done;
1268     // mark block done
1269     if (threadIdx.x == 0 && threadIdx.y == 0) {
1270       int old = atomicAdd(&semaphores[blockIdx.x], 1);
1271       is_last_block_done = (old == (gridDim.y-1));
1272     }
1273 
1274     __syncthreads();
1275 
1276     // check that all data is now available in global memory
1277     if (is_last_block_done) {
1278       sum_dy_th = accscalar_t(0.0);
1279       sum_dy_xmu_th = accscalar_t(0.0);
1280 
1281       for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
1282         address_base = c_offset + y * stride;
1283         sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
1284         sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
1285       }
1286 
1287       merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
1288       if (threadIdx.y == 0 && c_offset < stride) {
1289         if (grad_bias != nullptr) {
1290           grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
1291         }
1292         if (grad_weight != nullptr) {
1293           grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
1294         }
1295         //mean_dy[c_offset] = sum_dy_th / reduction_size;
1296         //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
1297         sum_dy_o[c_offset] = sum_dy_th;
1298         sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
1299       }
1300     }
1301   } else {
1302     if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
1303       if (grad_bias != nullptr) {
1304         grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
1305       }
1306       if (grad_weight != nullptr) {
1307         grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
1308       }
1309       //mean_dy[c_offset] = sum_dy_th / reduction_size;
1310       //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
1311       sum_dy_o[c_offset] = sum_dy_th;
1312       sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
1313     }
1314   }
1315 }
1316 
1317 // elementwise BN backward kernel
1318 // original apex name: batchnorm_backward_c_last_kernel
1319 template <
1320     int PARALLEL_LOADS,
1321     typename scalar_t,
1322     typename accscalar_t,
1323     typename layerscalar_t>
1324 __device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl(
1325       const scalar_t* __restrict__ grad_output,
1326       const scalar_t* __restrict__ input,
1327       const accscalar_t* __restrict__ mean,
1328       const accscalar_t* __restrict__ inv_std,
1329       const layerscalar_t* __restrict__ weight,
1330       const accscalar_t* __restrict__ sum_dy,
1331       const accscalar_t* __restrict__ sum_dy_xmu,
1332       scalar_t* __restrict__ grad_input,
1333       const accscalar_t norm_fct,
1334       const int reduction_size,
1335       const int stride) {
1336   // tensor dimension (m,c)
1337   // loop along m dimension
1338   int inner_loop_stride = blockDim.y * gridDim.y;
1339 
1340   // offset along m dimension
1341   int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
1342   int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
1343 
1344   if (c_offset >= stride || m_offset >= reduction_size) {
1345     return;
1346   }
1347 
1348   auto m_c = mean[c_offset];
1349   auto m_dy_c = sum_dy[c_offset] * norm_fct;
1350   auto factor_1_c = inv_std[c_offset];
1351   auto factor_2_c = (weight == nullptr? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
1352   factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct;
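  // With norm_fct = 1 / N (N = total element count per channel), the loop below evaluates
  //   grad_input = w * inv_std * (dy - sum_dy / N - (x - mean) * inv_std^2 * sum_dy_xmu / N),
  // i.e. the usual batch-norm input gradient with the channel-wise reductions precomputed.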
1353 
1354   int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
1355   int address_base = m_offset * stride + c_offset;
1356   int address_increment = inner_loop_stride * stride;
1357 
1358   for (int i = 0; i < loop_count; i++) {
1359 #pragma unroll
1360     for (int j = 0; j < PARALLEL_LOADS; j++) {
1361       if (c_offset < stride && m_offset < reduction_size) {
1362         grad_input[address_base] = static_cast<scalar_t>(
1363             (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
1364             (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
1365             * factor_2_c);
1366       }
1367       m_offset += inner_loop_stride;
1368       address_base += address_increment;
1369     }
1370   }
1371 }
1372 
1373 template <
1374     int PARALLEL_LOADS,
1375     typename scalar_t,
1376     typename accscalar_t,
1377     typename layerscalar_t>
1378 __global__ void batch_norm_backward_elemt_channels_last_kernel(
1379       const scalar_t* __restrict__ grad_output,
1380       const scalar_t* __restrict__ input,
1381       const accscalar_t* __restrict__ mean,
1382       const accscalar_t* __restrict__ inv_std,
1383       const layerscalar_t* __restrict__ weight,
1384       const accscalar_t* __restrict__ sum_dy,
1385       const accscalar_t* __restrict__ sum_dy_xmu,
1386       const int* __restrict__ numel,
1387       scalar_t* __restrict__ grad_input,
1388       const int64_t world_size,
1389       const int reduction_size,
1390       const int stride) {
1391 
1392   int64_t total_numel = 0;
1393   for (int i = 0; i < world_size; i++) {
1394     total_numel += numel[i];
1395   }
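  // numel holds one element count per contributing chunk (world_size entries - presumably the
  // per-replica counts of a synchronized batch norm), so norm_fct = 1 / total_numel turns the
  // precomputed sum_dy / sum_dy_xmu into means inside the shared impl below.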
1396 
1397   auto norm_fct = static_cast<accscalar_t>(1) / static_cast<accscalar_t>(total_numel);
1398   batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
1399       grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
1400       grad_input, norm_fct, reduction_size, stride);
1401 }
1402 
1403 template <
1404     int PARALLEL_LOADS,
1405     typename scalar_t,
1406     typename accscalar_t,
1407     typename layerscalar_t>
1408 __global__ void batch_norm_backward_elemt_channels_last_kernel(
1409       const scalar_t* __restrict__ grad_output,
1410       const scalar_t* __restrict__ input,
1411       const accscalar_t* __restrict__ mean,
1412       const accscalar_t* __restrict__ inv_std,
1413       const layerscalar_t* __restrict__ weight,
1414       const accscalar_t* __restrict__ sum_dy,
1415       const accscalar_t* __restrict__ sum_dy_xmu,
1416       scalar_t* __restrict__ grad_input,
1417       const accscalar_t norm_fct,
1418       const int reduction_size,
1419       const int stride) {
1420   batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
1421       grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
1422       grad_input, norm_fct, reduction_size, stride);
1423 }
1424 
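// VarTransform must be a stateless functor usable in device code as VarTransform{}(var, epsilon),
// mapping the biased variance m2/count to the published statistic (e.g. an inverse standard
// deviation). A hypothetical instantiation, for illustration only:
//   batch_norm_stats_channels_last_cuda_template<float, MyInvStdFunctor>(mean, invstd, nhwc_input, 1e-5);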
1425 template<typename scalar_t, typename VarTransform>
1426 void batch_norm_stats_channels_last_cuda_template(
1427     const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
1428   using accscalar_t = at::acc_type<scalar_t, true>;
1429 
1430   const auto stride = input.sizes()[1];
1431   const auto reduction_size = input.numel() / stride;
1432 
1433   resize_output(out_mean, {stride});
1434   resize_output(out_invstd, {stride});
1435   TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
1436                         out_invstd.sizes()[0]);
1437   TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
1438                         out_mean.sizes()[0]);
1439 
1440   dim3 block;
1441   dim3 grid;
1442   flexible_launch_configs(reduction_size, stride, block, grid, true);
1443 
1444   at::Tensor staging_data;
1445   at::Tensor semaphores;
1446   if (grid.y > 1) {
1447     staging_data = at::empty({4*stride*grid.y}, out_mean.options());
1448     semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
1449   }
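  // Staging layout consumed by the kernel: the first stride*grid.y accscalar_t slots hold staged
  // means, the next stride*grid.y hold m2n, and the per-channel counts are ints reinterpreted
  // from the slots after that; allocating 4*stride*grid.y elements leaves ample room for all three.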
1450 
1451   accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
1452   int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
1453   batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
1454       <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
1455       input.const_data_ptr<scalar_t>(),
1456       out_mean.mutable_data_ptr<accscalar_t>(),
1457       out_invstd.mutable_data_ptr<accscalar_t>(),
1458       staging_data_ptr,
1459       semaphores_ptr,
1460       reduction_size,
1461       stride,
1462       epsilon);
1463   C10_CUDA_KERNEL_LAUNCH_CHECK();
1464 }
1465 
1466 void batch_norm_elemt_channels_last_cuda_template(
1467     const at::Tensor& output,
1468     const at::Tensor& input,
1469     const at::Tensor& weight,
1470     const at::Tensor& shift,  // bias of BN
1471     const at::Tensor& mean,
1472     const at::Tensor& inv_std,
1473     const std::optional<at::Tensor>& z = std::nullopt,  // bias after BN
1474     const bool fuse_relu = false) {
1475   const auto stride = input.sizes()[1];
1476   const auto reduction_size = input.numel() / stride;
1477 
1478   dim3 block;
1479   dim3 grid;
1480   flexible_launch_configs(reduction_size, stride, block, grid);
1481 
1482   auto stream = at::cuda::getCurrentCUDAStream();
1483   const auto second_dtype = weight.defined() ? weight.scalar_type() :
1484       (shift.defined() ? shift.scalar_type() : input.scalar_type());
1485 
1486   if (input.scalar_type() != second_dtype) {
1487     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
1488       using accscalar_t = at::acc_type<scalar_t, true>;
1489       batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
1490           <<<grid, block, 0, stream>>>(
1491           input.const_data_ptr<scalar_t>(),
1492           z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
1493           mean.const_data_ptr<accscalar_t>(),
1494           inv_std.const_data_ptr<accscalar_t>(),
1495           weight.defined() ? weight.const_data_ptr<accscalar_t>() : nullptr,
1496           shift.defined() ? shift.const_data_ptr<accscalar_t>() : nullptr,
1497           output.mutable_data_ptr<scalar_t>(),
1498           reduction_size,
1499           stride,
1500           fuse_relu);
1501       C10_CUDA_KERNEL_LAUNCH_CHECK();
1502     });
1503   } else {
1504     if (weight.defined()){
1505       TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(),
1506         " is not supported with weight.scalar_type() ", weight.scalar_type());
1507     }
1508     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
1509       using accscalar_t = at::acc_type<scalar_t, true>;
1510       batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
1511           <<<grid, block, 0, stream>>>(
1512           input.const_data_ptr<scalar_t>(),
1513           z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
1514           mean.const_data_ptr<accscalar_t>(),
1515           inv_std.const_data_ptr<accscalar_t>(),
1516           weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
1517           shift.defined() ? shift.const_data_ptr<scalar_t>(): nullptr,
1518           output.mutable_data_ptr<scalar_t>(),
1519           reduction_size,
1520           stride,
1521           fuse_relu);
1522       C10_CUDA_KERNEL_LAUNCH_CHECK();
1523     });
1524   }
1525 }
1526 
1527 std::tuple<Tensor, Tensor, Tensor, Tensor>
1528 batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output,
1529     const at::Tensor& input,
1530     const at::Tensor& mean,
1531     const at::Tensor& inv_std,
1532     const at::Tensor& weight,
1533     const bool input_g, const bool weight_g, const bool bias_g) {
1534   const auto stride = input.sizes()[1];
1535   const auto reduction_size = input.numel() / stride;
1536 
1537   at::Tensor sumn_dy = at::empty({stride}, mean.options());
1538   at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
1539 
1540   at::Tensor grad_weight;
1541   at::Tensor grad_bias;
1542   if (weight.defined()) {
1543     grad_weight = at::empty({stride}, weight.options());
1544     grad_bias = at::empty({stride}, weight.options());
1545   } else {
1546     // we cannot return an uninitialized at::Tensor, so return empty placeholders instead
1547     grad_weight = at::empty({0}, mean.options());
1548     grad_bias = at::empty({0}, mean.options());
1549   }
1550 
1551   dim3 block;
1552   dim3 grid;
1553   flexible_launch_configs(reduction_size, stride, block, grid, true);
1554 
1555   at::Tensor staging_data;
1556   at::Tensor semaphores;
1557   if (grid.y > 1) {
1558     staging_data = at::empty({2*stride*grid.y}, mean.options());
1559     semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
1560   }
1561   auto stream = at::cuda::getCurrentCUDAStream();
1562 
1563   if (weight.defined() && input.scalar_type() != weight.scalar_type()) {
1564     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
1565       using accscalar_t = at::acc_type<scalar_t, true>;
1566       accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
1567       int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
1568       batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
1569           <<<grid, block, 0, stream>>>(
1570           input.const_data_ptr<scalar_t>(),
1571           grad_output.const_data_ptr<scalar_t>(),
1572           mean.const_data_ptr<accscalar_t>(),
1573           inv_std.const_data_ptr<accscalar_t>(),
1574           sumn_dy.mutable_data_ptr<accscalar_t>(),
1575           sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
1576           grad_weight.mutable_data_ptr<accscalar_t>(),
1577           grad_bias.mutable_data_ptr<accscalar_t>(),
1578           staging_data_ptr,
1579           semaphores_ptr,
1580           reduction_size,
1581           stride);
1582       C10_CUDA_KERNEL_LAUNCH_CHECK();
1583     });
1584   } else {
1585     if (weight.defined()) {
1586       TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(),
1587         " is not supported with weight.scalar_type() ", weight.scalar_type());
1588     }
1589     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
1590       using accscalar_t = at::acc_type<scalar_t, true>;
1591       accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
1592       int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
1593       batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
1594           <<<grid, block, 0, stream>>>(
1595           input.const_data_ptr<scalar_t>(),
1596           grad_output.const_data_ptr<scalar_t>(),
1597           mean.const_data_ptr<accscalar_t>(),
1598           inv_std.const_data_ptr<accscalar_t>(),
1599           sumn_dy.mutable_data_ptr<accscalar_t>(),
1600           sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
1601           weight.defined() ? grad_weight.mutable_data_ptr<scalar_t>() : nullptr,
1602           weight.defined() ? grad_bias.mutable_data_ptr<scalar_t>() : nullptr,
1603           staging_data_ptr,
1604           semaphores_ptr,
1605           reduction_size,
1606           stride);
1607       C10_CUDA_KERNEL_LAUNCH_CHECK();
1608     });
1609   }
1610 
1611   return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias);
1612 }
1613 
1614 at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
1615     const at::Tensor& grad_output,
1616     const at::Tensor& input,
1617     const at::Tensor& mean,
1618     const at::Tensor& inv_std,
1619     const at::Tensor& weight,
1620     const at::Tensor& sum_dy,
1621     const at::Tensor& sum_dy_xmu,
1622     const at::Tensor& count) {
1623   const auto stride = input.sizes()[1];
1624   const auto reduction_size = input.numel() / stride;
1625 
1626   // Input is guaranteed to be channels-last compatible
1627   at::Tensor grad_input = at::empty_like(input);
1628 
1629   dim3 block;
1630   dim3 grid;
1631   flexible_launch_configs(reduction_size, stride, block, grid);
1632 
1633   auto stream = at::cuda::getCurrentCUDAStream();
1634 
1635   if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
1636     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
1637       using accscalar_t = at::acc_type<scalar_t, true>;
1638       batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1639           <<<grid, block, 0, stream>>>(
1640           grad_output.const_data_ptr<scalar_t>(),
1641           input.const_data_ptr<scalar_t>(),
1642           mean.const_data_ptr<accscalar_t>(),
1643           inv_std.const_data_ptr<accscalar_t>(),
1644           weight.const_data_ptr<accscalar_t>(),
1645           sum_dy.const_data_ptr<accscalar_t>(),
1646           sum_dy_xmu.const_data_ptr<accscalar_t>(),
1647           count.const_data_ptr<int>(),
1648           grad_input.mutable_data_ptr<scalar_t>(),
1649           count.numel(),
1650           reduction_size,
1651           stride);
1652       C10_CUDA_KERNEL_LAUNCH_CHECK();
1653     });
1654   } else {
1655     if (weight.defined()) {
1656       TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(),
1657         " is not supported with weight.scalar_type() ", weight.scalar_type());
1658     }
1659     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
1660       using accscalar_t = at::acc_type<scalar_t, true>;
1661       batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1662           <<<grid, block, 0, stream>>>(
1663           grad_output.const_data_ptr<scalar_t>(),
1664           input.const_data_ptr<scalar_t>(),
1665           mean.const_data_ptr<accscalar_t>(),
1666           inv_std.const_data_ptr<accscalar_t>(),
1667           weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
1668           sum_dy.const_data_ptr<accscalar_t>(),
1669           sum_dy_xmu.const_data_ptr<accscalar_t>(),
1670           count.const_data_ptr<int>(),
1671           grad_input.mutable_data_ptr<scalar_t>(),
1672           count.numel(),
1673           reduction_size,
1674           stride);
1675       C10_CUDA_KERNEL_LAUNCH_CHECK();
1676     });
1677   }
1678 
1679   return grad_input;
1680 }
1681 
1682 at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
1683     const at::Tensor& grad_output,
1684     const at::Tensor& input,
1685     const at::Tensor& mean,
1686     const at::Tensor& inv_std,
1687     const at::Tensor& weight,
1688     const at::Tensor& sum_dy,
1689     const at::Tensor& sum_dy_xmu) {
1690   const auto stride = input.sizes()[1];
1691   const auto reduction_size = input.numel() / stride;
1692   auto norm_fct = 1.0 / reduction_size;
1693 
1694   // Input is guaranteed to be channels-last compatible
1695   at::Tensor grad_input = at::empty_like(input);
1696 
1697   dim3 block;
1698   dim3 grid;
1699   flexible_launch_configs(reduction_size, stride, block, grid);
1700 
1701   auto stream = at::cuda::getCurrentCUDAStream();
1702 
1703   AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
1704     using accscalar_t = at::acc_type<scalar_t, true>;
1705 
1706     if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
1707       batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1708           <<<grid, block, 0, stream>>>(
1709           grad_output.const_data_ptr<scalar_t>(),
1710           input.const_data_ptr<scalar_t>(),
1711           mean.const_data_ptr<accscalar_t>(),
1712           inv_std.const_data_ptr<accscalar_t>(),
1713           weight.const_data_ptr<accscalar_t>(),
1714           sum_dy.const_data_ptr<accscalar_t>(),
1715           sum_dy_xmu.const_data_ptr<accscalar_t>(),
1716           grad_input.mutable_data_ptr<scalar_t>(),
1717           static_cast<accscalar_t>(norm_fct),
1718           reduction_size,
1719           stride);
1720       C10_CUDA_KERNEL_LAUNCH_CHECK();
1721     } else {
1722       batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1723           <<<grid, block, 0, stream>>>(
1724           grad_output.const_data_ptr<scalar_t>(),
1725           input.const_data_ptr<scalar_t>(),
1726           mean.const_data_ptr<accscalar_t>(),
1727           inv_std.const_data_ptr<accscalar_t>(),
1728           weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
1729           sum_dy.const_data_ptr<accscalar_t>(),
1730           sum_dy_xmu.const_data_ptr<accscalar_t>(),
1731           grad_input.mutable_data_ptr<scalar_t>(),
1732           static_cast<accscalar_t>(norm_fct),
1733           reduction_size,
1734           stride);
1735       C10_CUDA_KERNEL_LAUNCH_CHECK();
1736     }
1737   });
1738 
1739   return grad_input;
1740 }
1741 
1742 } } // namespace at::native
1743