1 #pragma once
2
3 #include <ATen/core/Tensor.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/AccumulateType.h>
6 #include <ATen/ceil_div.h>
7 #include <ATen/cuda/CUDAContext.h>
8 #include <ATen/cuda/DeviceUtils.cuh>
9 #include <ATen/native/cuda/block_reduce.cuh>
10 #include <ATen/native/cuda/DeviceSqrt.cuh>
11 #include <ATen/native/cuda/LaunchUtils.h>
12 #include <c10/macros/Macros.h>
13
14 #ifndef AT_PER_OPERATOR_HEADERS
15 #include <ATen/Functions.h>
16 #else
17 #include <ATen/ops/empty.h>
18 #include <ATen/ops/empty_like.h>
19 #include <ATen/ops/zeros.h>
20 #endif
21
22 namespace at { namespace native {
23
24 // The maximum number of threads in a block
25 #if defined(USE_ROCM)
26 constexpr int MAX_BLOCK_SIZE = 256;
27 #else
28 constexpr int MAX_BLOCK_SIZE = 512;
29 #endif
30
31 constexpr unsigned MAX_GRID_SIZE = 65535u;
32
33 // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
34 static int getNumThreads(int nElem) {
35 #if defined(USE_ROCM)
36 int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
37 #else
38 int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
39 #endif
40 for (int i = 0; i != 5; ++i) {
41 if (nElem <= threadSizes[i]) {
42 return threadSizes[i];
43 }
44 }
45 return MAX_BLOCK_SIZE;
46 }
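// Usage sketch (mirrors the launch code in the templates below, e.g. batch_norm_backward_cuda_template):
// getNumThreads sizes the x-dimension of a 2D block and the rest of MAX_BLOCK_SIZE goes to y, i.e.
//   int tf = getNumThreads(input.size(2));
//   dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE / tf));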
47
48 // Returns the index of the most significant 1 bit in `val`.
49 __device__ __forceinline__ int getMSB(int val) {
50 return 31 - __clz(val);
51 }
52
53 template <typename scalar_t, typename accscalar_t>
54 struct Float2 {
55 accscalar_t v1, v2;
56 __device__ Float2() {}
57 __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {}
58 __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}
59 __device__ Float2& operator+=(const Float2& a) {
60 v1 += a.v1;
61 v2 += a.v2;
62 return *this;
63 }
64 __device__ friend Float2 operator+(Float2 a, const Float2& b) {
65 a += b;
66 return a;
67 }
68 };
69
70 template <typename scalar_t, typename accscalar_t, typename PTA>
71 struct GradOp {
72 __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
73 : mean(m), input(i), grad_output(g) {}
74 __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
75 accscalar_t g = grad_output[batch][plane][n];
76 accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
77 return Float2<scalar_t, accscalar_t>(g, g * c);
78 }
79 const accscalar_t mean;
80 const PTA& input;
81 const PTA& grad_output;
82 };
83
84 template <typename acc_t>
85 struct SumReduceOp {
86 __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
87
88 __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
89 return WARP_SHFL_DOWN(data, offset);
90 }
91 };
92
93 template <typename scalar_t, typename accscalar_t>
94 struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
95 using acc_t = Float2<scalar_t, accscalar_t>;
96
97 __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
98
99 __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
100 return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)};
101 }
102 };
103
104 // Sum across (batch, x/y/z) applying Op() pointwise
105 // this works by first having each thread sum its part
106 // of the data. Then there is a double-shuffling reduction.
107 // First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its
108 // data to the "warp leader", who writes its value into shared memory.
109 // Then a single warp reads the remaining (at most C10_WARP_SIZE) items
110 // and reduces them using another warpSum.
111 // The implicit assumption is that there are no more
112 // than C10_WARP_SIZE**2 threads.
113 template<typename scalar_t, typename Op, typename PTA>
114 __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
115 // first the reductions each thread does separately
116 scalar_t sum = static_cast<scalar_t>(0);
117 for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
118 for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
119 sum += op(batch, plane, x);
120 }
121 }
122 __shared__ scalar_t shared[C10_WARP_SIZE];
123 SumReduceOp<scalar_t> reduce_op;
124 sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);
125 if (threadIdx.x == 0 && threadIdx.y == 0) {
126 shared[0] = sum;
127 }
128 __syncthreads();
129 // Everyone picks it up, should be broadcast into the whole grad_input
130 return shared[0];
131 }
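// Usage sketch (this is how the backward kernels below call it): accumulate sum(dy) and
// sum(dy * (x - mean)) for one plane in a single pass, e.g.
//   GradOp<scalar_t, accscalar_t, PTA> g(mean, input, grad_output);
//   auto res = reduce<Float2<scalar_t, accscalar_t>>(g, grad_output, plane);
//   // res.v1 == sum(dy), res.v2 == sum(dy * (x - mean))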
132
133 constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency
134 constexpr int ELEMENTS_PER_THREAD = 16;
135 constexpr int OPTIMAL_TILE_W = 32;
136 constexpr int MAX_H_BLOCK = 128;
137
138 __host__ void flexible_launch_configs(
139 const int reduction,
140 const int stride,
141 dim3 &block,
142 dim3 &grid,
143 const bool coop_flag = false) {
144 int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
145 int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)),
146 MAX_BLOCK_SIZE / block_x);
147 if (block_x * block_y != MAX_BLOCK_SIZE) {
148 block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
149 }
150
151 int grid_x = at::ceil_div(stride, block_x);
152 int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
153 if (coop_flag) {
154 // it's not worth having a grid reduction if the reduction dimension is not big enough
155 grid_y = grid_y < 8 ? 1 : grid_y;
156 }
157
158 block.x = block_x;
159 block.y = block_y;
160 block.z = 1;
161 grid.x = grid_x;
162 grid.y = grid_y;
163 grid.z = 1;
164 }
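// Usage sketch (assumed caller; the channels-last host templates are not shown in this header):
//   dim3 block, grid;
//   flexible_launch_configs(reduction_size, stride, block, grid, /*coop_flag=*/true);
//   batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
//       <<<grid, block, 0, stream>>>(...);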
165
166 template<typename T, typename C>
167 __device__ __forceinline__ void welford_merge_element(C& count,
168 T& mean,
169 T& m2n,
170 const C& count_new,
171 const T& mean_new,
172 const T& m2n_new) {
173 T factor = T(1.0) / ::max(1, (count + count_new));
174 T delta0 = mean - mean_new;
175 mean = (mean_new * count_new + mean * count) * factor;
176 m2n += m2n_new + delta0 * delta0 * count_new * count * factor;
177 count += count_new;
178 }
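// The merge above is the pairwise (Chan et al.) update: given two partial aggregates
// (count_a, mean_a, M2_a) and (count_b, mean_b, M2_b),
//   mean = (count_a * mean_a + count_b * mean_b) / (count_a + count_b)
//   M2   = M2_a + M2_b + (mean_a - mean_b)^2 * count_a * count_b / (count_a + count_b)
// A minimal host-side sketch of the same update (illustrative only, not used by the kernels;
// the name welford_merge_element_host is hypothetical):
inline void welford_merge_element_host(int& count, float& mean, float& m2n,
                                       int count_new, float mean_new, float m2n_new) {
  float factor = 1.0f / std::max(1, count + count_new);  // guards against merging two empty aggregates
  float delta0 = mean - mean_new;
  mean = (mean_new * count_new + mean * count) * factor;
  m2n += m2n_new + delta0 * delta0 * count_new * count * factor;
  count += count_new;
}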
179
180 // merge mean/m2n among threadIdx.y within block
181 template<typename T, typename C>
182 __device__ __forceinline__ void welford_merge_block_vertical(C& count,
183 T& mean,
184 T& m2n,
185 C* shmem_count,
186 T* shmem_mean,
187 T* shmem_m2n) {
188 // write to shared memory
189 auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
190
191 #pragma unroll
192 for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
193 if (threadIdx.y < offset*2) {
194 shmem_mean[address_base] = mean;
195 shmem_m2n[address_base] = m2n;
196 shmem_count[address_base] = count;
197 }
198 __syncthreads();
199 if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
200 auto address = address_base + offset * blockDim.x;
201 // read shared memory back to register for reduction
202 auto count_new = shmem_count[address];
203 auto mean_new = shmem_mean[address];
204 auto m2n_new = shmem_m2n[address];
205
206 welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new);
207 }
208 }
209 }
210
211 template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
212 __global__ void batch_norm_transform_input_kernel(
213 const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
214 GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
215 const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
216 const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
217 const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
218 const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
219 stat_accscalar_t epsilon) {
220
221 index_t plane = blockIdx.x;
222
223 if (plane >= input.size(1)) {
224 return;
225 }
226
227 stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1);
228 stat_accscalar_t beta = bias.size(0) > 0 ? static_cast<stat_accscalar_t>(bias[plane]) : static_cast<stat_accscalar_t>(0);
229 stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);
230 stat_accscalar_t invstd;
231 if (train) {
232 invstd = var_or_invstd[plane];
233 } else {
234 invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(var_or_invstd[plane]) + epsilon);
235 }
236
237 index_t bs = input.size(0);
238 index_t fs = input.size(2);
239
240 index_t bstep = blockDim.y * gridDim.y;
241 for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
242 auto o = output[batch][plane];
243 auto i = input[batch][plane];
244 for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
245 o[feature] = static_cast<input_scalar_t>(gamma * (i[feature] - mean) * invstd + beta);
246 }
247 }
248 }
249
250 struct InvStd {
251 template <typename T>
252 __device__ __forceinline__ T operator()(T var, double epsilon) const {
253 T invstd = 0;
254 if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
255 invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
256 }
257 return invstd;
258 }
259 };
260
261 struct Var {
262 template <typename T>
263 __device__ __forceinline__ T operator()(T var, double epsilon) const {
264 return var;
265 }
266 };
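// VarTransform functors for batch_norm_collect_statistics_kernel (and its channels-last
// counterpart): InvStd turns the biased variance into 1/sqrt(var + eps) for save_invstd,
// while Var passes the biased variance through unchanged. Both are applied as
// VarTransform{}(m2 / N, epsilon) when the per-channel statistics are written out.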
267
268 template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
269 __global__ void batch_norm_collect_statistics_kernel(
270 const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
271 const stat_accscalar_t epsilon,
272 const stat_accscalar_t momentum,
273 GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
274 GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {
275
276 __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE];
277
278 int plane = blockIdx.x;
279 int N = input.size(0) * input.size(2);
280 int tid = threadIdx.x + threadIdx.y * blockDim.x;
281
282 // Compute the mean and variance across (batch, x/y/z)
283 // This uses Welford's online algorithm (in the for loop) to accumulate each thread's statistics,
284 // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
285 // and the parallel algorithm on the same page to merge them across the block.
286 // We use two shuffles to reduce across the entire block.
287 // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
288 stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE];
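// Layout note: the first C10_WARP_SIZE ints of shared_n hold the per-warp counts; the
// remaining 2 * 2 * C10_WARP_SIZE ints are reinterpreted via shared_avg_var as up to
// 2 * C10_WARP_SIZE accscalar_t values (an avg/var_n pair per warp), which assumes
// sizeof(stat_accscalar_t) <= 2 * sizeof(int).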
289
290 // first the reductions each thread does separately
291 stat_accscalar_t avg = 0;
292 stat_accscalar_t var_n = 0;
293 int n = 0;
294 for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
295 for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
296 stat_accscalar_t v = input[batch][plane][x];
297 stat_accscalar_t d1 = v - avg;
298 n++;
299 avg += d1 / n;
300 var_n += d1 * (v - avg);
301 }
302 }
303
304 // first warpSum to reduce from one value per thread
305 // to one value per warp
306 for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
307 stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
308 int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
309 stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
310 var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
311 avg = (n * avg + o_n * o_avg) * factor;
312 n += o_n;
313 }
314
315 // this writes each warp's item into shared memory
316 // there are at most C10_WARP_SIZE items left because
317 // there are at most C10_WARP_SIZE**2 threads at the beginning
318 __syncthreads();
319 if (tid % C10_WARP_SIZE == 0) {
320 shared_n[tid / C10_WARP_SIZE] = n;
321 shared_avg_var[tid / C10_WARP_SIZE * 2] = avg;
322 shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n;
323 }
324 __syncthreads();
325 // now have a second warpSum to reduce the intermediate values
326 // from shared memory to a single number. The very first
327 // thread writes it to shared memory.
328
329 if (tid < C10_WARP_SIZE) {
330 n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0);
331 avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
332 var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
333 }
334 for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
335 stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
336 int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
337 stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
338 var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
339 avg = (n * avg + o_n * o_avg) * factor;
340 n += o_n;
341 }
342
343 // Save the mean, variance, and moving averages
344 if (tid == 0) {
345 if (save_mean.data() != NULL) {
346 save_mean[plane] = avg;
347 }
348 if (save_transformed_var.data() != NULL) {
349 save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
350 }
351 }
352
353 }
354
355 template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
356 __global__ void batch_norm_backward_kernel(
357 const GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t> input,
358 const GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
359 GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
360 GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
361 GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias,
362 const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
363 const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> running_mean,
364 const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> running_var,
365 const GenericPackedTensorAccessor<const stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_mean,
366 const GenericPackedTensorAccessor<const stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd,
367 bool train,
368 stat_accscalar_t epsilon) {
369
370 index_t plane = blockIdx.x;
371 index_t N = grad_output.size(0) * grad_output.size(2);
372
373 stat_accscalar_t mean, invstd;
374 if (train) {
375 mean = save_mean[plane];
376 invstd = save_invstd[plane];
377 } else {
378 mean = static_cast<stat_accscalar_t>(running_mean[plane]);
379 invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(running_var[plane]) + epsilon);
380 }
381
382 stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
383 stat_accscalar_t norm = stat_accscalar_t(1) / N;
384
385 // Compute two values across (batch, x/y/z) in one pass:
386 // 1. Sum(grad_output)
387 // 2. DotProduct(input - mean, grad_output)
388 GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output);
389 auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);
390
391 stat_accscalar_t grad_output_sum = res.v1;
392 stat_accscalar_t dot_p = res.v2;
393
394 stat_accscalar_t grad_mean = grad_output_sum * norm;
395 stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd;
396 stat_accscalar_t grad_scale = invstd * weight_val;
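// In training mode the loop below then computes the standard batch norm input gradient,
//   grad_input = (dy - sum(dy)/N - (x - mean) * invstd^2 * dot_p / N) * invstd * weight,
// while in eval mode only the affine scaling (dy * invstd * weight) is applied.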
397
398 if (grad_input.data() != NULL) {
399 for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
400 for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
401 input_scalar_t go = grad_output[batch][plane][x];
402 if (train) {
403 stat_accscalar_t inp = input[batch][plane][x];
404 stat_accscalar_t proj = (inp - mean) * proj_scale;
405 grad_input[batch][plane][x] = static_cast<input_scalar_t>((go - proj - grad_mean) * grad_scale);
406 } else {
407 grad_input[batch][plane][x] = static_cast<input_scalar_t>(go * grad_scale);
408 }
409 }
410 }
411 }
412
413 if (grad_weight.size(0) > 0) {
414 if (threadIdx.x == 0) {
415 grad_weight[plane] = static_cast<stat_scalar_t>(dot_p * invstd);
416 }
417 }
418
419 if (grad_bias.size(0) > 0) {
420 if (threadIdx.x == 0) {
421 grad_bias[plane] = static_cast<stat_scalar_t>(grad_output_sum);
422 }
423 }
424 }
425
426 template <typename scalar_t, typename accscalar_t, typename index_t>
427 __global__ void batch_norm_reduce_statistics_kernel(
428 const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
429 const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
430 GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
431 GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
432 GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
433 GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
434 const accscalar_t epsilon,
435 const accscalar_t momentum,
436 const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> counts) {
437
438 int feature_size = vec_mean.size(1);
439 int world_size = vec_mean.size(0);
440
441 int bid = blockIdx.x;
442 int tid = threadIdx.x;
443
444 // first the reductions each thread does separately
445 for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
446 accscalar_t avg = 0;
447 accscalar_t var_n = 0;
448 index_t n = 0;
449 for (int j = 0; j < world_size; j++) {
450 scalar_t count = counts[j];
451 accscalar_t m = vec_mean[j][i];
452 accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
453 v = (v * v - epsilon) * count;
454 accscalar_t factor = 1.0 / (n + count);
455 var_n += v + (avg - m) * (avg - m) * n * count * factor;
456 avg = n * factor * avg + count * factor * m;
457 n += count;
458 }
459 mean[i] = avg;
460 invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);
461 if (running_mean.data() != NULL) {
462 running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg);
463 }
464 accscalar_t unbiasedVar = var_n / (n - 1);
465 if (running_var.data() != NULL) {
466 running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar);
467 }
468 }
469
470 }
471
472 template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
473 __global__ void batch_norm_backward_reduce_kernel(
474 const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
475 const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
476 GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
477 GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
478 GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
479 GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
480 GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
481 GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {
482
483 index_t plane = blockIdx.x;
484
485 stat_accscalar_t r_mean = mean[plane];
486 stat_accscalar_t factor = invstd[plane];
487
488 GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
489 auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);
490
491 if (threadIdx.x == 0) {
492 if (grad_weight.size(0) > 0) {
493 grad_weight[plane] = static_cast<stat_scalar_t>(res.v2 * factor);
494 }
495 if (grad_bias.size(0) > 0) {
496 grad_bias[plane] = static_cast<stat_scalar_t>(res.v1);
497 }
498 if (sum_dy.size(0) > 0) {
499 sum_dy[plane] = static_cast<stat_accscalar_t>(res.v1);
500 }
501 if (sum_dy_xmu.size(0) > 0) {
502 sum_dy_xmu[plane] = static_cast<stat_accscalar_t>(res.v2);
503 }
504 }
505 }
506
507 template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
508 __device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl(
509 const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
510 const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
511 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
512 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
513 const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
514 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
515 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
516 GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
517 const stat_accscalar_t norm_fct) {
518 index_t plane = blockIdx.x;
519
520 if (plane >= input.size(1)) {
521 return;
522 }
523
524 stat_accscalar_t m_c = mean[plane];
525 stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct;
526 stat_accscalar_t factor_1_c = invstd[plane];
527 stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
528 factor_2_c *= factor_1_c;
529 factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct;
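// With norm_fct = 1/N (N = total elements reduced per channel), the loop below evaluates
//   grad_input = (dy - sum_dy/N - (x - mean) * invstd^2 * sum_dy_xmu/N) * invstd * weight,
// i.e. the same input gradient as batch_norm_backward_kernel, but using precomputed sums.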
530
531 index_t bs = input.size(0);
532 index_t fs = input.size(2);
533
534 index_t bstep = blockDim.y * gridDim.y;
535 for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
536 auto g_i = grad_input[batch][plane];
537 auto g_o = grad_output[batch][plane];
538 auto i = input[batch][plane];
539 for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
540 g_i[feature] = static_cast<input_scalar_t>((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c);
541 }
542 }
543 }
544
545 template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
546 __global__ void batch_norm_backward_elemt_kernel(
547 const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
548 const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
549 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
550 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
551 const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
552 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
553 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
554 GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
555 const int* __restrict__ numel, const int world_size) {
556 int64_t total_numel = 0;
557 for (int i = 0; i < world_size; i ++) {
558 total_numel += numel[i];
559 }
560
561 const stat_accscalar_t norm_fct =
562 static_cast<stat_accscalar_t>(1) / static_cast<stat_accscalar_t>(total_numel);
563 batch_norm_backward_elemt_kernel_impl(
564 input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
565 }
566
567 template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
568 __global__ void batch_norm_backward_elemt_kernel(
569 const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
570 const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
571 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
572 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
573 const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
574 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
575 const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
576 GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
577 const stat_accscalar_t norm_fct) {
578 batch_norm_backward_elemt_kernel_impl(
579 input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
580 }
581
582 template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
583 static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
584 const Tensor& t, c10::string_view var_name) {
585 constexpr auto expect_type = c10::CppTypeToScalarType<typename std::remove_const<scalar_t>::type>::value;
586 const auto actual_type = t.scalar_type();
587 TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
588 " to have type ", expect_type, " but got ", actual_type);
589 return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
590 }
591
592 template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
593 static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
594 const Tensor& t, c10::string_view var_name) {
595 if (!t.defined()) {
596 const std::array<index_t, dim> zeros{{0}};
597 return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
598 }
599 return get_packed_accessor<scalar_t, dim, PtrTraits, index_t>(t, var_name);
600 }
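// Usage sketch (mirrors the templates below): inputs are first reshaped to (N, C, *), then
// accessors are built with an explicit name that is used only in the dtype error message, e.g.
//   auto input  = get_packed_accessor<const input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
//   auto weight = packed_accessor_or_dummy<const stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
// packed_accessor_or_dummy yields a zero-sized accessor for an undefined tensor, which the
// kernels detect with size(0) > 0 / data() != NULL checks instead of branching on the host.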
601
602 template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
603 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
604 const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
605 bool train, double epsilon, std::array<bool,3> grad_input_mask) {
606
607 using accscalar_t = at::acc_type<stat_scalar_t, true>;
608 Tensor grad_input_;
609 Tensor grad_input_reshaped;
610 Tensor grad_weight_;
611 Tensor grad_bias_;
612 auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
613 auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
614
615 if (grad_input_mask[0]) {
616 grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
617 grad_input_reshaped = grad_input_.view(input_reshaped.sizes());
618 }
619 if (grad_input_mask[1]) {
620 grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
621 }
622 if (grad_input_mask[2]) {
623 grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
624 }
625
626 auto input = get_packed_accessor<
627 const input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
628 auto grad_output = get_packed_accessor<
629 const input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
630 auto grad_input = packed_accessor_or_dummy<
631 input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
632 auto weight = packed_accessor_or_dummy<
633 const stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
634 auto grad_weight = packed_accessor_or_dummy<
635 stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
636 auto grad_bias = packed_accessor_or_dummy<
637 stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
638 auto running_mean = packed_accessor_or_dummy<
639 const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean");
640 auto running_var = packed_accessor_or_dummy<
641 const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var");
642 auto save_mean = packed_accessor_or_dummy<
643 const accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean");
644 auto save_invstd = packed_accessor_or_dummy<
645 const accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd");
646
647 auto stream = at::cuda::getCurrentCUDAStream();
648 dim3 blocks(input.size(1));
649 int tf = getNumThreads(input.size(2));
650 dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
651
652 batch_norm_backward_kernel<input_scalar_t, stat_scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
653 (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var,
654 save_mean, save_invstd, train, epsilon);
655 C10_CUDA_KERNEL_LAUNCH_CHECK();
656
657 return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
658 }
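// A hypothetical call site (the actual dispatch lives outside this header; this is only a
// sketch of how the template parameters line up, assuming 32-bit indexing is usable):
//   AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batch_norm_backward", [&] {
//     batch_norm_backward_cuda_template<scalar_t, scalar_t, int32_t>(
//         grad_out, input, weight, running_mean, running_var, save_mean, save_invstd,
//         train, epsilon, grad_input_mask);
//   });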
659
660 template<typename scalar_t, typename index_t, typename VarTransform>
661 void batch_norm_stats_cuda_template(
662 const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {
663
664 using accscalar_t = at::acc_type<scalar_t, true>;
665 int64_t n_input = input_.size(1);
666 Tensor dummy_mean_;
667 Tensor dummy_var_;
668 auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
669
670 resize_output(out_mean, {n_input});
671 resize_output(out_invstd, {n_input});
672 auto input = get_packed_accessor<
673 const scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
674 TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
675 out_invstd.sizes()[0]);
676 TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
677 out_mean.sizes()[0]);
678
679 auto mean = packed_accessor_or_dummy<
680 accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean");
681 auto invstd = packed_accessor_or_dummy<
682 accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd");
683 auto stream = at::cuda::getCurrentCUDAStream();
684
685 dim3 blocks(input.size(1));
686 int tf = getNumThreads(input.size(2));
687 dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
688 batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
689 (input, epsilon, 0.0, mean, invstd);
690 C10_CUDA_KERNEL_LAUNCH_CHECK();
691 }
692
693 template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
694 void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_,
695 const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) {
696
697 using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
698 int64_t n_input = input_.size(1);
699 auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
700 auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});
701
702 auto input = get_packed_accessor<
703 const input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
704 auto output = get_packed_accessor<
705 input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
706 auto weight = packed_accessor_or_dummy<
707 const stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
708 auto bias = packed_accessor_or_dummy<
709 const stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
710 auto mean = packed_accessor_or_dummy<
711 stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
712 auto invstd = packed_accessor_or_dummy<
713 stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd");
714 auto stream = at::cuda::getCurrentCUDAStream();
715
716 // NOTE: We use transform_input_kernel in training mode, which ignores epsilon
717 const double dummy_epsilon = 1e-5;
718
719 // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
720 // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
721 // and good occupancy. Quite likely, we could go with even more blocks than 1024.
722 // The various planes are independent, so we use blocks for them.
723 int tf = std::max<int>(getNumThreads(input.size(2)/4),
724 std::min<int>(getNumThreads(input.size(2)), 64));
725 int tb = std::max<int>(64/tf, 1);
726 dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
727 (input.size(0)+tb-1)/tb)));
728 blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
729 dim3 threads_trans(tf, tb);
730 batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
731 (input, output, mean, invstd, weight, bias, dummy_epsilon);
732 C10_CUDA_KERNEL_LAUNCH_CHECK();
733 }
734
735 template<typename scalar_t, typename accscalar_t, typename index_t>
736 std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
737 const Tensor& running_mean_, const Tensor& running_var_,
738 double momentum, double epsilon, const Tensor& counts_) {
739
740 Tensor save_mean_;
741 Tensor save_invstd_;
742
743 auto features = mean_.size(1);
744 auto input_options = mean_.options();
745 if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) {
746 input_options = input_options.dtype(ScalarType::Float);
747 }
748 save_mean_ = at::empty({features}, input_options);
749 save_invstd_ = at::empty({features}, input_options);
750
751 auto mean = packed_accessor_or_dummy<
752 accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean");
753 auto invstd = packed_accessor_or_dummy<
754 accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd");
755 auto running_mean = packed_accessor_or_dummy<
756 scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean");
757 auto running_var = packed_accessor_or_dummy<
758 scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_var");
759 auto counts = packed_accessor_or_dummy<
760 scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts");
761
762 auto save_mean = get_packed_accessor<
763 accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean");
764 auto save_invstd = get_packed_accessor<
765 accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd");
766 auto stream = at::cuda::getCurrentCUDAStream();
767
768 int block = getNumThreads(features);
769 int grid = std::max<int>(1, features/block);
770 batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
771 (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts);
772 C10_CUDA_KERNEL_LAUNCH_CHECK();
773
774 return std::make_tuple(save_mean_, save_invstd_);
775 }
776
777 template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
778 std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
779 const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_,
780 const bool input_g, const bool weight_g, const bool bias_g) {
781
782 using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
783 int64_t n_input = input_.size(1);
784 Tensor sum_dy_;
785 Tensor sum_dy_xmu_;
786 Tensor grad_weight_;
787 Tensor grad_bias_;
788 auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
789 auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
790
791 if (input_g) {
792 sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
793 sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
794 }
795 if (weight_g) {
796 grad_weight_ = at::empty({n_input}, weight_.options());
797 }
798 if (bias_g) {
799 grad_bias_ = at::empty({n_input}, weight_.options());
800 }
801
802 auto input = get_packed_accessor<
803 input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
804 auto grad_output = get_packed_accessor<
805 input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
806 auto grad_weight = packed_accessor_or_dummy<
807 stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
808 auto grad_bias = packed_accessor_or_dummy<
809 stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
810 auto mean = packed_accessor_or_dummy<
811 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
812 auto invstd = packed_accessor_or_dummy<
813 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
814 auto sum_dy = packed_accessor_or_dummy<
815 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
816 auto sum_dy_xmu = packed_accessor_or_dummy<
817 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
818
819 auto batch_size = input_reshaped.size(0);
820 auto feature_size = input_reshaped.size(2);
821 auto stream = at::cuda::getCurrentCUDAStream();
822
823 int warp_size = at::cuda::warp_size();
824 int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size);
825 // We want block_x to be at least a warp width
826 int block_x = std::min<int>(std::max<int>(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y);
827 const dim3 block(block_x, block_y);
828 const dim3 grid(n_input);
829
830 batch_norm_backward_reduce_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<grid, block, 0, stream>>>
831 (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias);
832 C10_CUDA_KERNEL_LAUNCH_CHECK();
833
834 return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_);
835 }
836
837 template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
838 Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
839 const Tensor& mean_, const Tensor& invstd_,
840 const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) {
841
842 using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
843 int64_t n_input = input_.size(1);
844 auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
845 auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
846 auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
847
848 auto input = get_packed_accessor<
849 input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
850 auto grad_input = get_packed_accessor<
851 input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
852 auto grad_output = get_packed_accessor<
853 input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
854 auto mean = packed_accessor_or_dummy<
855 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
856 auto invstd = packed_accessor_or_dummy<
857 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
858 auto weight = packed_accessor_or_dummy<
859 stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
860 auto sum_dy = packed_accessor_or_dummy<
861 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
862 auto sum_dy_xmu = packed_accessor_or_dummy<
863 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
864
865 auto stream = at::cuda::getCurrentCUDAStream();
866
867 // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
868 // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
869 // and good occupancy. Quite likely, we could go with even more blocks than 1024.
870 // The various planes are independent, so we use blocks for them.
871 int tf = std::max<int>(getNumThreads(input.size(2)/4),
872 std::min<int>(getNumThreads(input.size(2)), 64));
873 int tb = std::max<int>(64/tf, 1);
874 dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
875 (input.size(0)+tb-1)/tb)));
876 blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
877 dim3 threads_trans(tf, tb);
878 auto reduction_size = input_.numel() / n_input;
879 auto norm_fct = static_cast<stat_accscalar_t>(1.0 / reduction_size);
880 batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t>
881 <<<blocks_trans, threads_trans, 0, stream>>>
882 (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
883 C10_CUDA_KERNEL_LAUNCH_CHECK();
884
885 return grad_input_reshaped.view(input_.sizes());
886 }
887
888 template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
889 Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
890 const Tensor& mean_, const Tensor& invstd_,
891 const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) {
892
893 using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
894 int64_t n_input = input_.size(1);
895 auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
896 auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
897 auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
898
899 auto input = get_packed_accessor<
900 input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
901 auto grad_input = get_packed_accessor<
902 input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
903 auto grad_output = get_packed_accessor<
904 input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
905 auto mean = packed_accessor_or_dummy<
906 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
907 auto invstd = packed_accessor_or_dummy<
908 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
909 auto weight = packed_accessor_or_dummy<
910 stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
911 auto sum_dy = packed_accessor_or_dummy<
912 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
913 auto sum_dy_xmu = packed_accessor_or_dummy<
914 stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
915
916 auto stream = at::cuda::getCurrentCUDAStream();
917
918 // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
919 // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
920 // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
921 // The various planes are independent, so we use blocks for them.
922 int tf = std::max<int>(getNumThreads(input.size(2)/4),
923 std::min<int>(getNumThreads(input.size(2)), 64));
924 int tb = std::max<int>(64/tf, 1);
925 dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
926 (input.size(0)+tb-1)/tb)));
927 blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
928 dim3 threads_trans(tf, tb);
929 batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
930 (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.const_data_ptr<int>(), count.numel());
931 C10_CUDA_KERNEL_LAUNCH_CHECK();
932
933 return grad_input_reshaped.view(input_.sizes());
934 }
935
936 // Welford kernel for channels-last tensors, calculating mean/biased_variance/unbiased_variance
937 // original apex name: welford_kernel_c_last
938 template
939 <typename VarTransform,
940 typename scalar_t,
941 typename accscalar_t,
942 int PARALLEL_LOADS>
943 __global__ void
944 batch_norm_collect_statistics_channels_last_kernel(
945 const scalar_t* __restrict__ input,
946 accscalar_t* __restrict__ out_mean,
947 accscalar_t* __restrict__ out_invstd,
948 volatile accscalar_t* staging_data,
949 int* semaphores,
950 const int reduction_size,
951 const int stride,
952 accscalar_t epsilon) {
953 // hide latency with concurrency
954 accscalar_t x_mean[PARALLEL_LOADS];
955 accscalar_t m_2_n[PARALLEL_LOADS];
956 int count[PARALLEL_LOADS];
957
958 #pragma unroll
959 for (int i = 0; i < PARALLEL_LOADS; i++) {
960 x_mean[i] = accscalar_t(0);
961 m_2_n[i] = accscalar_t(0);
962 count[i] = accscalar_t(0);
963 }
964 // tensor dimension (m,c)
965
966 // loop along m dimension
967 int inner_loop_stride = blockDim.y * gridDim.y;
968
969 // offset along m dimension
970 int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
971 int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
972
973 int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
974 int address_base = m_offset * stride + c_offset;
975 int address_increment = inner_loop_stride * stride;
976
977 for (int i = 0; i < loop_count; i++) {
978 accscalar_t x_math[PARALLEL_LOADS];
979 accscalar_t x_count_inv[PARALLEL_LOADS];
980 accscalar_t is_valid[PARALLEL_LOADS];
981
982 // load multiple data in
983 #pragma unroll
984 for (int j = 0; j < PARALLEL_LOADS; j++) {
985 if (c_offset < stride && m_offset < reduction_size) {
986 x_math[j] = input[address_base];
987 count[j]++;
988 x_count_inv[j] = accscalar_t(1) / count[j];
989 is_valid[j] = accscalar_t(1);
990 } else {
991 x_math[j] = accscalar_t(0);
992 x_count_inv[j] = accscalar_t(0);
993 is_valid[j] = accscalar_t(0);
994 }
995 m_offset += inner_loop_stride;
996 address_base += address_increment;
997 }
998
999 // calculate mean/m2n with welford
1000 #pragma unroll
1001 for (int j = 0; j < PARALLEL_LOADS; j++) {
1002 accscalar_t delta0 = x_math[j] - x_mean[j];
1003 x_mean[j] += delta0 * x_count_inv[j];
1004 accscalar_t delta1 = x_math[j] - x_mean[j];
1005 m_2_n[j] += delta0 * delta1 * is_valid[j];
1006 }
1007 }
1008
1009 // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
1010 #pragma unroll
1011 for (int j = 1; j < PARALLEL_LOADS; j++) {
1012 welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
1013 }
1014
1015 // release x_mean / m_2_n
1016 auto mean_th = x_mean[0];
1017 auto m2_th = m_2_n[0];
1018 auto count_th = count[0];
1019
1020 // block-wise reduction with shared memory (since reduction cannot be done within a warp)
1021 static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
1022 static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
1023 static __shared__ int shmem_count[MAX_BLOCK_SIZE];
1024
1025 welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
1026
1027 if (gridDim.y > 1) {
1028 volatile accscalar_t* staging_mean = staging_data;
1029 volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
1030 volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);
1031
1032 address_base = c_offset + blockIdx.y * stride;
1033 // write data to staging_data;
1034 if (threadIdx.y == 0 && c_offset < stride) {
1035 staging_mean[address_base] = mean_th;
1036 staging_m2n[address_base] = m2_th;
1037 staging_count[address_base] = count_th;
1038 }
1039
1040 __threadfence();
1041 __syncthreads(); // ensuring writes to staging_ are visible to all blocks
1042
1043 __shared__ bool is_last_block_done;
1044 // mark block done
1045 if (threadIdx.x == 0 && threadIdx.y == 0) {
1046 int old = atomicAdd(&semaphores[blockIdx.x], 1);
1047 is_last_block_done = (old == (gridDim.y-1));
1048 }
1049
1050 __syncthreads();
1051
1052 // check that all data is now available in global memory
1053 if (is_last_block_done) {
1054 count_th = 0;
1055 mean_th = accscalar_t(0.0);
1056 m2_th = accscalar_t(0.0);
1057
1058 for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
1059 address_base = c_offset + y * stride;
1060 int count_new = c_offset < stride ? staging_count[address_base] : 0;
1061 accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
1062 accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
1063
1064 welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new);
1065 }
1066
1067 welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
1068 if (threadIdx.y == 0 && c_offset < stride) {
1069 out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
1070 out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
1071 }
1072 }
1073 } else {
1074 if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
1075 out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
1076 out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
1077 }
1078 }
1079 }
1080
1081 // elementwise BN kernel
1082 // original apex name: batchnorm_forward_c_last_kernel
1083 template <
1084 typename scalar_t,
1085 typename accscalar_t,
1086 typename layerscalar_t,
1087 int PARALLEL_LOADS>
1088 __global__ void batch_norm_transform_input_channels_last_kernel(
1089 const scalar_t* __restrict__ input,
1090 const scalar_t* __restrict__ z,
1091 const accscalar_t* __restrict__ mean,
1092 const accscalar_t* __restrict__ inv_std,
1093 const layerscalar_t* __restrict__ weight,
1094 const layerscalar_t* __restrict__ shift,
1095 scalar_t* __restrict__ out,
1096 const int reduction_size,
1097 const int stride,
1098 const bool fuse_relu) {
1099 // tensor dimension (m,c)
1100 // loop along m dimension
1101 int inner_loop_stride = blockDim.y * gridDim.y;
1102
1103 // offset along m dimension
1104 int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
1105 int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
1106
1107 if (c_offset >= stride || m_offset >= reduction_size) {
1108 return;
1109 }
1110
1111 auto m_c = mean[c_offset];
1112 auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
1113 auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
1114 auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
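  // Per-element transform applied below:
  //   out = w_c * (x - mean_c) * inv_std_c + s_c   (+ z if provided, then optional ReLU)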
1115
1116 int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
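  // loop_count above is ceil(reduction_size / (inner_loop_stride * PARALLEL_LOADS)).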
1117 int address_base = m_offset * stride + c_offset;
1118 int address_increment = inner_loop_stride * stride;
1119
1120 for (int i = 0; i < loop_count; i++) {
1121 #pragma unroll
1122 for (int j = 0; j < PARALLEL_LOADS; j++) {
1123 if (c_offset < stride && m_offset < reduction_size) {
1124 auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
1125 if (z != nullptr) {
1126 tmp += z[address_base];
1127 }
1128 out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
1129 }
1130 m_offset += inner_loop_stride;
1131 address_base += address_increment;
1132 }
1133 }
1134 }
1135
1136 template<typename T>
1137 __device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy,
1138 T& sum_dy_xmu,
1139 T* shmem_sum_dy,
1140 T* shmem_sum_dy_xmu) {
1141 // write to shared memory
1142 auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
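  // Tree reduction along threadIdx.y: each round halves the number of active
  // rows, with row y accumulating the partial sums staged by row y + offset.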
1143
1144 #pragma unroll
1145 for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
1146 if (threadIdx.y < offset*2) {
1147 shmem_sum_dy[address_base] = sum_dy;
1148 shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
1149 }
1150 __syncthreads();
1151 if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
1152 auto address = address_base + offset * blockDim.x;
1153
1154 sum_dy += shmem_sum_dy[address];
1155 sum_dy_xmu += shmem_sum_dy_xmu[address];
1156 }
1157 }
1158 }
1159
1160 // batchnorm backward reduce kernel for channels-last tensors
1161 // original apex name: reduce_bn_c_last_kernel
1162 template <
1163 int PARALLEL_LOADS,
1164 typename scalar_t,
1165 typename accscalar_t,
1166 typename layerscalar_t>
1167 __global__ void batch_norm_backward_reduce_channels_last_kernel(
1168 const scalar_t* __restrict__ input,
1169 const scalar_t* __restrict__ grad_output,
1170 const accscalar_t* __restrict__ mean,
1171 const accscalar_t* __restrict__ inv_std,
1172 accscalar_t* __restrict__ sum_dy_o,
1173 accscalar_t* __restrict__ sum_dy_xmu_o,
1174 layerscalar_t* __restrict__ grad_weight,
1175 layerscalar_t* __restrict__ grad_bias,
1176 volatile accscalar_t* staging_data,
1177 int* semaphores,
1178 const int reduction_size,
1179 const int stride) {
1180
1181 // hide latency with concurrency
1182 accscalar_t sum_dy[PARALLEL_LOADS];
1183 accscalar_t sum_dy_xmu[PARALLEL_LOADS];
1184
1185 #pragma unroll
1186 for (int i = 0; i < PARALLEL_LOADS; i++) {
1187 sum_dy[i] = accscalar_t(0);
1188 sum_dy_xmu[i] = accscalar_t(0);
1189 }
1190 // tensor dimension (m,c)
1191
1192 // loop along m dimension
1193 int inner_loop_stride = blockDim.y * gridDim.y;
1194
1195 // offset along m dimension
1196 int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
1197 int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
1198
1199 if (c_offset >= stride || m_offset >= reduction_size) {
1200 return;
1201 }
1202
1203 int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
1204 int address_base = m_offset * stride + c_offset;
1205 int address_increment = inner_loop_stride * stride;
1206
1207 auto r_mean = mean[c_offset];
1208 auto factor = inv_std[c_offset];
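  // Per-channel reductions produced by this kernel:
  //   sum_dy     = sum over (n, hw) of grad_output
  //   sum_dy_xmu = sum over (n, hw) of grad_output * (input - mean)
  // grad_bias is sum_dy and grad_weight is sum_dy_xmu * inv_std (written at the end).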
1209
1210 for (int i = 0; i < loop_count; i++) {
1211 accscalar_t x_input[PARALLEL_LOADS];
1212 accscalar_t x_grad_output[PARALLEL_LOADS];
1213
1214    // load PARALLEL_LOADS elements per thread
1215 #pragma unroll
1216 for (int j = 0; j < PARALLEL_LOADS; j++) {
1217 if (c_offset < stride && m_offset < reduction_size) {
1218 x_input[j] = input[address_base];
1219 x_grad_output[j] = grad_output[address_base];
1220 } else {
1221 x_input[j] = accscalar_t(0);
1222 x_grad_output[j] = accscalar_t(0);
1223 }
1224 m_offset += inner_loop_stride;
1225 address_base += address_increment;
1226 }
1227
1228 // calculate sum_dy / sum_dy_xmu
1229 #pragma unroll
1230 for (int j = 0; j < PARALLEL_LOADS; j++) {
1231 sum_dy[j] += x_grad_output[j];
1232 sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
1233 }
1234 }
1235
1236  // thread-local reduction accumulating sum_dy / sum_dy_xmu across the PARALLEL_LOADS lanes
1237 #pragma unroll
1238 for (int j = 1; j < PARALLEL_LOADS; j++) {
1239 sum_dy[0] += sum_dy[j];
1240 sum_dy_xmu[0] += sum_dy_xmu[j];
1241 }
1242
1243 // release array of registers
1244 auto sum_dy_th = sum_dy[0];
1245 auto sum_dy_xmu_th = sum_dy_xmu[0];
1246
1247  // block-wise reduction with shared memory (the reduction runs along threadIdx.y,
1248  // so it cannot be done with warp shuffles alone)
1248 static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
1249 static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
1250
1251 merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
1252
1253 if (gridDim.y > 1) {
1254 volatile accscalar_t* staging_sum_dy = staging_data;
1255 volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
1256
1257 address_base = c_offset + blockIdx.y * stride;
1258 // write data to staging_data;
1259 if (threadIdx.y == 0 && c_offset < stride) {
1260 staging_sum_dy[address_base] = sum_dy_th;
1261 staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
1262 }
1263
1264 __threadfence();
1265    __syncthreads(); // ensure writes to staging_ are visible to all blocks
1266
1267 __shared__ bool is_last_block_done;
1268 // mark block done
1269 if (threadIdx.x == 0 && threadIdx.y == 0) {
1270 int old = atomicAdd(&semaphores[blockIdx.x], 1);
1271 is_last_block_done = (old == (gridDim.y-1));
1272 }
1273
1274 __syncthreads();
1275
1276 // check that all data is now available in global memory
1277 if (is_last_block_done) {
1278 sum_dy_th = accscalar_t(0.0);
1279 sum_dy_xmu_th = accscalar_t(0.0);
1280
1281 for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
1282 address_base = c_offset + y * stride;
1283 sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
1284 sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
1285 }
1286
1287 merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
1288 if (threadIdx.y == 0 && c_offset < stride) {
1289 if (grad_bias != nullptr) {
1290 grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
1291 }
1292 if (grad_weight != nullptr) {
1293 grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
1294 }
1295 //mean_dy[c_offset] = sum_dy_th / reduction_size;
1296 //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
1297 sum_dy_o[c_offset] = sum_dy_th;
1298 sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
1299 }
1300 }
1301 } else {
1302 if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
1303 if (grad_bias != nullptr) {
1304 grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
1305 }
1306 if (grad_weight != nullptr) {
1307 grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
1308 }
1309 //mean_dy[c_offset] = sum_dy_th / reduction_size;
1310 //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
1311 sum_dy_o[c_offset] = sum_dy_th;
1312 sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
1313 }
1314 }
1315 }
1316
1317 // elementwise BN backward kernel
1318 // original apex name: batchnorm_backward_c_last_kernel
1319 template <
1320 int PARALLEL_LOADS,
1321 typename scalar_t,
1322 typename accscalar_t,
1323 typename layerscalar_t>
1324 __device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl(
1325 const scalar_t* __restrict__ grad_output,
1326 const scalar_t* __restrict__ input,
1327 const accscalar_t* __restrict__ mean,
1328 const accscalar_t* __restrict__ inv_std,
1329 const layerscalar_t* __restrict__ weight,
1330 const accscalar_t* __restrict__ sum_dy,
1331 const accscalar_t* __restrict__ sum_dy_xmu,
1332 scalar_t* __restrict__ grad_input,
1333 const accscalar_t norm_fct,
1334 const int reduction_size,
1335 const int stride) {
1336 // tensor dimension (m,c)
1337 // loop along m dimension
1338 int inner_loop_stride = blockDim.y * gridDim.y;
1339
1340 // offset along m dimension
1341 int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
1342 int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
1343
1344 if (c_offset >= stride || m_offset >= reduction_size) {
1345 return;
1346 }
1347
1348 auto m_c = mean[c_offset];
1349 auto m_dy_c = sum_dy[c_offset] * norm_fct;
1350 auto factor_1_c = inv_std[c_offset];
1351 auto factor_2_c = (weight == nullptr? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
1352 factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct;
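  // With norm_fct = 1/N this yields the usual batch norm input gradient,
  //   grad_input = (dy - sum_dy/N - (x - mean) * inv_std^2 * sum_dy_xmu/N) * weight * inv_std,
  // where m_dy_c = sum_dy/N, factor_1_c = inv_std^2 * sum_dy_xmu/N and
  // factor_2_c = weight * inv_std.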
1353
1354 int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
1355 int address_base = m_offset * stride + c_offset;
1356 int address_increment = inner_loop_stride * stride;
1357
1358 for (int i = 0; i < loop_count; i++) {
1359 #pragma unroll
1360 for (int j = 0; j < PARALLEL_LOADS; j++) {
1361 if (c_offset < stride && m_offset < reduction_size) {
1362 grad_input[address_base] = static_cast<scalar_t>(
1363 (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
1364 (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
1365 * factor_2_c);
1366 }
1367 m_offset += inner_loop_stride;
1368 address_base += address_increment;
1369 }
1370 }
1371 }
1372
1373 template <
1374 int PARALLEL_LOADS,
1375 typename scalar_t,
1376 typename accscalar_t,
1377 typename layerscalar_t>
1378 __global__ void batch_norm_backward_elemt_channels_last_kernel(
1379 const scalar_t* __restrict__ grad_output,
1380 const scalar_t* __restrict__ input,
1381 const accscalar_t* __restrict__ mean,
1382 const accscalar_t* __restrict__ inv_std,
1383 const layerscalar_t* __restrict__ weight,
1384 const accscalar_t* __restrict__ sum_dy,
1385 const accscalar_t* __restrict__ sum_dy_xmu,
1386 const int* __restrict__ numel,
1387 scalar_t* __restrict__ grad_input,
1388 const int64_t world_size,
1389 const int reduction_size,
1390 const int stride) {
1391
1392 int64_t total_numel = 0;
1393 for (int i = 0; i < world_size; i++) {
1394 total_numel += numel[i];
1395 }
1396
1397 auto norm_fct = static_cast<accscalar_t>(1) / static_cast<accscalar_t>(total_numel);
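  // norm_fct = 1 / N_total, where N_total sums the per-rank counts in
  // numel[0..world_size), i.e. the total number of elements reduced per channel
  // across all participating processes.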
1398 batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
1399 grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
1400 grad_input, norm_fct, reduction_size, stride);
1401 }
1402
1403 template <
1404 int PARALLEL_LOADS,
1405 typename scalar_t,
1406 typename accscalar_t,
1407 typename layerscalar_t>
1408 __global__ void batch_norm_backward_elemt_channels_last_kernel(
1409 const scalar_t* __restrict__ grad_output,
1410 const scalar_t* __restrict__ input,
1411 const accscalar_t* __restrict__ mean,
1412 const accscalar_t* __restrict__ inv_std,
1413 const layerscalar_t* __restrict__ weight,
1414 const accscalar_t* __restrict__ sum_dy,
1415 const accscalar_t* __restrict__ sum_dy_xmu,
1416 scalar_t* __restrict__ grad_input,
1417 const accscalar_t norm_fct,
1418 const int reduction_size,
1419 const int stride) {
1420 batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
1421 grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
1422 grad_input, norm_fct, reduction_size, stride);
1423 }
1424
1425 template<typename scalar_t, typename VarTransform>
1426 void batch_norm_stats_channels_last_cuda_template(
1427 const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
1428 using accscalar_t = at::acc_type<scalar_t, true>;
1429
1430 const auto stride = input.sizes()[1];
1431 const auto reduction_size = input.numel() / stride;
1432
1433 resize_output(out_mean, {stride});
1434 resize_output(out_invstd, {stride});
1435 TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
1436 out_invstd.sizes()[0]);
1437 TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
1438 out_mean.sizes()[0]);
1439
1440 dim3 block;
1441 dim3 grid;
1442 flexible_launch_configs(reduction_size, stride, block, grid, true);
1443
1444 at::Tensor staging_data;
1445 at::Tensor semaphores;
1446 if (grid.y > 1) {
1447 staging_data = at::empty({4*stride*grid.y}, out_mean.options());
1448 semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
1449 }
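  // staging_data layout expected by the collect-statistics kernel:
  // stride*grid.y means, then stride*grid.y m2n values (both accscalar_t), with
  // the remaining space reinterpreted as the int counters; allocating
  // 4*stride*grid.y accscalar_t elements leaves ample room for those ints.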
1450
1451 accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
1452 int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
1453 batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
1454 <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
1455 input.const_data_ptr<scalar_t>(),
1456 out_mean.mutable_data_ptr<accscalar_t>(),
1457 out_invstd.mutable_data_ptr<accscalar_t>(),
1458 staging_data_ptr,
1459 semaphores_ptr,
1460 reduction_size,
1461 stride,
1462 epsilon);
1463 C10_CUDA_KERNEL_LAUNCH_CHECK();
1464 }
1465
1466 void batch_norm_elemt_channels_last_cuda_template(
1467 const at::Tensor& output,
1468 const at::Tensor& input,
1469 const at::Tensor& weight,
1470 const at::Tensor& shift, // bias of BN
1471 const at::Tensor& mean,
1472 const at::Tensor& inv_std,
1473 const std::optional<at::Tensor>& z = std::nullopt, // bias after BN
1474 const bool fuse_relu = false) {
1475 const auto stride = input.sizes()[1];
1476 const auto reduction_size = input.numel() / stride;
1477
1478 dim3 block;
1479 dim3 grid;
1480 flexible_launch_configs(reduction_size, stride, block, grid);
1481
1482 auto stream = at::cuda::getCurrentCUDAStream();
1483 const auto second_dtype = weight.defined() ? weight.scalar_type() :
1484 (shift.defined() ? shift.scalar_type() : input.scalar_type());
1485
1486 if (input.scalar_type() != second_dtype) {
1487 AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
1488 using accscalar_t = at::acc_type<scalar_t, true>;
1489 batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
1490 <<<grid, block, 0, stream>>>(
1491 input.const_data_ptr<scalar_t>(),
1492 z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
1493 mean.const_data_ptr<accscalar_t>(),
1494 inv_std.const_data_ptr<accscalar_t>(),
1495 weight.defined() ? weight.const_data_ptr<accscalar_t>() : nullptr,
1496 shift.defined() ? shift.const_data_ptr<accscalar_t>() : nullptr,
1497 output.mutable_data_ptr<scalar_t>(),
1498 reduction_size,
1499 stride,
1500 fuse_relu);
1501 C10_CUDA_KERNEL_LAUNCH_CHECK();
1502 });
1503 } else {
1504 if (weight.defined()){
1505 TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(),
1506 " is not supported with weight.scalar_type() ", weight.scalar_type());
1507 }
1508 AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
1509 using accscalar_t = at::acc_type<scalar_t, true>;
1510 batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
1511 <<<grid, block, 0, stream>>>(
1512 input.const_data_ptr<scalar_t>(),
1513 z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
1514 mean.const_data_ptr<accscalar_t>(),
1515 inv_std.const_data_ptr<accscalar_t>(),
1516 weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
1517 shift.defined() ? shift.const_data_ptr<scalar_t>(): nullptr,
1518 output.mutable_data_ptr<scalar_t>(),
1519 reduction_size,
1520 stride,
1521 fuse_relu);
1522 C10_CUDA_KERNEL_LAUNCH_CHECK();
1523 });
1524 }
1525 }
1526
1527 std::tuple<Tensor, Tensor, Tensor, Tensor>
1528 batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output,
1529 const at::Tensor& input,
1530 const at::Tensor& mean,
1531 const at::Tensor& inv_std,
1532 const at::Tensor& weight,
1533 const bool input_g, const bool weight_g, const bool bias_g) {
1534 const auto stride = input.sizes()[1];
1535 const auto reduction_size = input.numel() / stride;
1536
1537 at::Tensor sumn_dy = at::empty({stride}, mean.options());
1538 at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
1539
1540 at::Tensor grad_weight;
1541 at::Tensor grad_bias;
1542 if (weight.defined()) {
1543 grad_weight = at::empty({stride}, weight.options());
1544 grad_bias = at::empty({stride}, weight.options());
1545 } else {
1546 // because I cannot return an uninitialized at::Tensor
1547 grad_weight = at::empty({0}, mean.options());
1548 grad_bias = at::empty({0}, mean.options());
1549 }
1550
1551 dim3 block;
1552 dim3 grid;
1553 flexible_launch_configs(reduction_size, stride, block, grid, true);
1554
1555 at::Tensor staging_data;
1556 at::Tensor semaphores;
1557 if (grid.y > 1) {
1558 staging_data = at::empty({2*stride*grid.y}, mean.options());
1559 semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
1560 }
1561 auto stream = at::cuda::getCurrentCUDAStream();
1562
1563 if (weight.defined() && input.scalar_type() != weight.scalar_type()) {
1564 AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
1565 using accscalar_t = at::acc_type<scalar_t, true>;
1566 accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
1567 int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
1568 batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
1569 <<<grid, block, 0, stream>>>(
1570 input.const_data_ptr<scalar_t>(),
1571 grad_output.const_data_ptr<scalar_t>(),
1572 mean.const_data_ptr<accscalar_t>(),
1573 inv_std.const_data_ptr<accscalar_t>(),
1574 sumn_dy.mutable_data_ptr<accscalar_t>(),
1575 sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
1576 grad_weight.mutable_data_ptr<accscalar_t>(),
1577 grad_bias.mutable_data_ptr<accscalar_t>(),
1578 staging_data_ptr,
1579 semaphores_ptr,
1580 reduction_size,
1581 stride);
1582 C10_CUDA_KERNEL_LAUNCH_CHECK();
1583 });
1584 } else {
1585 if (weight.defined()) {
1586 TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(),
1587 " is not supported with weight.scalar_type() ", weight.scalar_type());
1588 }
1589 AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
1590 using accscalar_t = at::acc_type<scalar_t, true>;
1591 accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
1592 int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
1593 batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
1594 <<<grid, block, 0, stream>>>(
1595 input.const_data_ptr<scalar_t>(),
1596 grad_output.const_data_ptr<scalar_t>(),
1597 mean.const_data_ptr<accscalar_t>(),
1598 inv_std.const_data_ptr<accscalar_t>(),
1599 sumn_dy.mutable_data_ptr<accscalar_t>(),
1600 sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
1601 weight.defined() ? grad_weight.mutable_data_ptr<scalar_t>() : nullptr,
1602 weight.defined() ? grad_bias.mutable_data_ptr<scalar_t>() : nullptr,
1603 staging_data_ptr,
1604 semaphores_ptr,
1605 reduction_size,
1606 stride);
1607 C10_CUDA_KERNEL_LAUNCH_CHECK();
1608 });
1609 }
1610
1611 return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias);
1612 }
1613
1614 at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
1615 const at::Tensor& grad_output,
1616 const at::Tensor& input,
1617 const at::Tensor& mean,
1618 const at::Tensor& inv_std,
1619 const at::Tensor& weight,
1620 const at::Tensor& sum_dy,
1621 const at::Tensor& sum_dy_xmu,
1622 const at::Tensor& count) {
1623 const auto stride = input.sizes()[1];
1624 const auto reduction_size = input.numel() / stride;
1625
1626   // Input is guaranteed to be channels-last compatible
1627 at::Tensor grad_input = at::empty_like(input);
1628
1629 dim3 block;
1630 dim3 grid;
1631 flexible_launch_configs(reduction_size, stride, block, grid);
1632
1633 auto stream = at::cuda::getCurrentCUDAStream();
1634
1635 if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
1636 AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
1637 using accscalar_t = at::acc_type<scalar_t, true>;
1638 batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1639 <<<grid, block, 0, stream>>>(
1640 grad_output.const_data_ptr<scalar_t>(),
1641 input.const_data_ptr<scalar_t>(),
1642 mean.const_data_ptr<accscalar_t>(),
1643 inv_std.const_data_ptr<accscalar_t>(),
1644 weight.const_data_ptr<accscalar_t>(),
1645 sum_dy.const_data_ptr<accscalar_t>(),
1646 sum_dy_xmu.const_data_ptr<accscalar_t>(),
1647 count.const_data_ptr<int>(),
1648 grad_input.mutable_data_ptr<scalar_t>(),
1649 count.numel(),
1650 reduction_size,
1651 stride);
1652 C10_CUDA_KERNEL_LAUNCH_CHECK();
1653 });
1654 } else {
1655 if (weight.defined()) {
1656 TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(),
1657 " is not supported with weight.scalar_type() ", weight.scalar_type());
1658 }
1659 AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
1660 using accscalar_t = at::acc_type<scalar_t, true>;
1661 batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1662 <<<grid, block, 0, stream>>>(
1663 grad_output.const_data_ptr<scalar_t>(),
1664 input.const_data_ptr<scalar_t>(),
1665 mean.const_data_ptr<accscalar_t>(),
1666 inv_std.const_data_ptr<accscalar_t>(),
1667 weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
1668 sum_dy.const_data_ptr<accscalar_t>(),
1669 sum_dy_xmu.const_data_ptr<accscalar_t>(),
1670 count.const_data_ptr<int>(),
1671 grad_input.mutable_data_ptr<scalar_t>(),
1672 count.numel(),
1673 reduction_size,
1674 stride);
1675 C10_CUDA_KERNEL_LAUNCH_CHECK();
1676 });
1677 }
1678
1679 return grad_input;
1680 }
1681
1682 at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
1683 const at::Tensor& grad_output,
1684 const at::Tensor& input,
1685 const at::Tensor& mean,
1686 const at::Tensor& inv_std,
1687 const at::Tensor& weight,
1688 const at::Tensor& sum_dy,
1689 const at::Tensor& sum_dy_xmu) {
1690 const auto stride = input.sizes()[1];
1691 const auto reduction_size = input.numel() / stride;
1692 auto norm_fct = 1.0 / reduction_size;
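  // Single-process case: each channel reduces over reduction_size = numel / stride
  // elements, so the normalization factor is simply 1 / reduction_size.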
1693
1694   // Input is guaranteed to be channels-last compatible
1695 at::Tensor grad_input = at::empty_like(input);
1696
1697 dim3 block;
1698 dim3 grid;
1699 flexible_launch_configs(reduction_size, stride, block, grid);
1700
1701 auto stream = at::cuda::getCurrentCUDAStream();
1702
1703 AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
1704 using accscalar_t = at::acc_type<scalar_t, true>;
1705
1706 if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
1707 batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1708 <<<grid, block, 0, stream>>>(
1709 grad_output.const_data_ptr<scalar_t>(),
1710 input.const_data_ptr<scalar_t>(),
1711 mean.const_data_ptr<accscalar_t>(),
1712 inv_std.const_data_ptr<accscalar_t>(),
1713 weight.const_data_ptr<accscalar_t>(),
1714 sum_dy.const_data_ptr<accscalar_t>(),
1715 sum_dy_xmu.const_data_ptr<accscalar_t>(),
1716 grad_input.mutable_data_ptr<scalar_t>(),
1717 static_cast<accscalar_t>(norm_fct),
1718 reduction_size,
1719 stride);
1720 C10_CUDA_KERNEL_LAUNCH_CHECK();
1721 } else {
1722 batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1723 <<<grid, block, 0, stream>>>(
1724 grad_output.const_data_ptr<scalar_t>(),
1725 input.const_data_ptr<scalar_t>(),
1726 mean.const_data_ptr<accscalar_t>(),
1727 inv_std.const_data_ptr<accscalar_t>(),
1728 weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
1729 sum_dy.const_data_ptr<accscalar_t>(),
1730 sum_dy_xmu.const_data_ptr<accscalar_t>(),
1731 grad_input.mutable_data_ptr<scalar_t>(),
1732 static_cast<accscalar_t>(norm_fct),
1733 reduction_size,
1734 stride);
1735 C10_CUDA_KERNEL_LAUNCH_CHECK();
1736 }
1737 });
1738
1739 return grad_input;
1740 }
1741
1742 } } // namespace at::native
1743