// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Normalization.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/Normalization.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Resize.h>
#include <ATen/native/cuda/Normalization.cuh>
#include <c10/cuda/CUDAMathCompat.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/batch_norm_backward_elemt_native.h>
#include <ATen/ops/batch_norm_backward_reduce_native.h>
#include <ATen/ops/batch_norm_elemt_native.h>
#include <ATen/ops/batch_norm_gather_stats_native.h>
#include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
#include <ATen/ops/batch_norm_stats_native.h>
#include <ATen/ops/batch_norm_update_stats_native.h>
#include <ATen/ops/cudnn_batch_norm.h>
#include <ATen/ops/cudnn_batch_norm_backward.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/miopen_batch_norm.h>
#include <ATen/ops/miopen_batch_norm_backward.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at::native {

namespace {

ScalarType first_type() {
  return ScalarType::Undefined;
}

template <typename... Args>
ScalarType first_type(const Tensor& arg, const Args&... parameters) {
  return arg.defined() ? arg.scalar_type() : first_type(parameters...);
}

// A transform is mixed type if the parameters are higher precision than the input
template <typename... Args>
bool is_mixed_type(const Tensor& input, const Args&... parameters) {
  const auto parameter_type = first_type(parameters...);
  return ((parameter_type != ScalarType::Undefined) &&
          (parameter_type != input.scalar_type()));
}
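// Illustrative example (not from the original source): a kHalf or kBFloat16 input
// normalized with kFloat weight/bias is "mixed type"; the dispatch sites below then
// instantiate the kernels with the float accumulate type for the parameters.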

inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
  return (
    self.is_contiguous(at::MemoryFormat::ChannelsLast) ||
    self.is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
    (self.is_contiguous() && self.strides()[1] == 1)
  );
}

enum class Impl {
  Contiguous,
  ChannelsLast,
  General,
};

inline Impl batch_norm_choose_impl(const Tensor& self) {
  if (!at::cuda::detail::canUse32BitIndexMath(self)) {
    return Impl::General;
  }

  if (self.is_contiguous()) {
    return self.strides()[1] == 1 ? Impl::ChannelsLast : Impl::Contiguous;
  }

  if (self.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    return Impl::ChannelsLast;
  }

  return Impl::General;
}

inline Impl batch_norm_choose_impl(const Tensor& in1, const Tensor& in2) {
  auto imp1 = batch_norm_choose_impl(in1);
  if (imp1 == Impl::General) {
    return imp1;
  }
  auto imp2 = batch_norm_choose_impl(in2);
  return imp1 == imp2 ? imp1 : Impl::General;
}

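// Applies the per-channel affine normalization
//   y = (x - mean) * invstd * weight + bias
// choosing between a contiguous kernel, a channels-last kernel, and a generic
// TensorIterator path based on the input's memory format.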
void batch_norm_elementwise(
    const Tensor& out, const Tensor& self, const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt, const Tensor& mean_, const Tensor& invstd_) {
  switch (batch_norm_choose_impl(self)) {
  case Impl::Contiguous: {
    c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
    c10::MaybeOwned<Tensor> bias = at::borrow_from_optional_tensor(bias_opt);
    resize_output(out, self.sizes());
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(),
                                    "batch_norm_elementwise_cuda", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const bool mixed_type = is_mixed_type(self, *weight, *bias);
      if (mixed_type) {
        batch_norm_elemt_cuda_template<scalar_t, accscalar_t, int32_t>(
            out, self, *weight, *bias, mean_, invstd_);
      } else {
        batch_norm_elemt_cuda_template<scalar_t, scalar_t, int32_t>(
            out, self, *weight, *bias, mean_, invstd_);
      }
    });
    return;
  }
  case Impl::ChannelsLast: {
    auto weight = at::borrow_from_optional_tensor(weight_opt);
    auto bias = at::borrow_from_optional_tensor(bias_opt);

    if (resize_output_check(out, self.sizes())) {
        resize_impl_cuda_(out.unsafeGetTensorImpl(), self.sizes(), self.strides());
    }
    if ((out.strides() == self.strides()) &&
        (!weight->defined() || weight->is_contiguous()) &&
        (!bias->defined() || bias->is_contiguous()) &&
        (!mean_.defined() || mean_.is_contiguous()) &&
        (!invstd_.defined() || invstd_.is_contiguous())) {
      batch_norm_elemt_channels_last_cuda_template(
          out, self, *weight, *bias, mean_, invstd_);
      return;
    }
    [[fallthrough]];
  }
  case Impl::General: {
    const int64_t ndim = self.dim();
    DimVector sizes(ndim, 1), strides(ndim, 0);
    // Helper to convert 1d tensors to an nd tensor that broadcasts with input
    // All elements go into the channel dimension
    auto as_nd = [&](const Tensor& t) {
      TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1);
      sizes[1] = t.sizes()[0];
      strides[1] = t.strides()[0];
      return t.as_strided(sizes, strides);
    };

    auto weight = weight_opt.has_value() && weight_opt->defined() ?
        as_nd(*weight_opt) : at::scalar_tensor(1, mean_.options());
    auto bias = bias_opt.has_value() && bias_opt->defined() ?
        as_nd(*bias_opt) : at::scalar_tensor(0, mean_.options());
    auto mean = as_nd(mean_);
    auto invstd = as_nd(invstd_);

    auto iter = TensorIteratorConfig()
        .add_output(out)
        .add_input(self)
        .add_input(weight)
        .add_input(bias)
        .add_input(mean)
        .add_input(invstd)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(),
                                    "batch_norm_elementwise_cuda", [&] {
      using acc_t = at::acc_type<scalar_t, true>;
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t input, acc_t weight, acc_t bias,
                                      acc_t mean, acc_t invstd) -> scalar_t {
        return (input - mean) * weight * invstd + bias;
      });
    });
    return;
  }
  }
}

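// Training-mode elementwise backward: computes grad_input from the per-channel
// reductions sum_dy and sum_dy_xmu (see batch_norm_backward_reduce_cuda), again
// selecting a contiguous, channels-last, or generic TensorIterator implementation.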
Tensor batch_norm_elementwise_backward_train(
    const Tensor& grad_out, const Tensor& input, const Tensor& mean, const Tensor& invstd,
    const Tensor& weight, const Tensor& sum_dy, const Tensor& sum_dy_xmu) {
  switch (batch_norm_choose_impl(input, grad_out)) {
  case Impl::Contiguous: {
    return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                           "batch_norm_backward_elemt", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const bool mixed_type = is_mixed_type(input, weight);
      if (mixed_type) {
        return batch_norm_backward_elemt_cuda_template<scalar_t, accscalar_t, int32_t>(
            grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
      } else {
        return batch_norm_backward_elemt_cuda_template<scalar_t, scalar_t, int32_t>(
            grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
      }
    });
  }
  case Impl::ChannelsLast: {
    if ((!weight.defined() || weight.is_contiguous()) &&
        mean.is_contiguous() && invstd.is_contiguous()) {
      return batch_norm_backward_elemt_channels_last_cuda_template(
          grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
    }
    [[fallthrough]];
  }
  case Impl::General: {
    const auto ndim = input.dim();
    DimVector sizes(ndim, 1), strides(ndim, 0);
    auto as_nd = [&](const Tensor& t) {
      TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1);
      sizes[1] = t.sizes()[0];
      strides[1] = t.strides()[0];
      return t.as_strided(sizes, strides);
    };
    auto invstd_nd = as_nd(invstd);
    auto mean_nd = as_nd(mean);
    auto sum_dy_nd = as_nd(sum_dy);
    auto sum_dy_xmu_nd = as_nd(sum_dy_xmu);
    auto weight_nd = weight.defined() ? as_nd(weight) :
        at::scalar_tensor(1.0, input.options().dtype(mean.scalar_type()));

    Tensor grad_input = at::empty(input.sizes(), grad_out.options().memory_format(input.suggest_memory_format()));
    auto iter = TensorIteratorConfig()
        .add_output(grad_input)
        .add_input(grad_out)
        .add_input(input)
        .add_input(weight_nd)
        .add_input(mean_nd)
        .add_input(invstd_nd)
        .add_input(sum_dy_xmu_nd)
        .add_input(sum_dy_nd)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_out.scalar_type(),
                                    "batch_norm_eval_backward", [&]{
      using accscalar_t = at::acc_type<scalar_t, true>;
      auto norm_fct = static_cast<accscalar_t>(1.0 / (input.numel() /input.size(1)) );
      gpu_kernel(iter, [norm_fct] GPU_LAMBDA (scalar_t gO, scalar_t input, accscalar_t weight,
                                              accscalar_t mean, accscalar_t invstd,
                                              accscalar_t xmu, accscalar_t dy) -> scalar_t {
        auto factor_1_c = invstd * invstd * xmu * norm_fct;
        auto factor_2_c = weight * invstd;
        auto m_dy_c = dy * norm_fct;
        return (gO - m_dy_c - (input - mean) * factor_1_c) * factor_2_c;
      });
    });
    return grad_input;
  }
  }
  TORCH_INTERNAL_ASSERT(false);
}

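// Eval-mode elementwise backward: with fixed (running) statistics the gradient is
// simply grad_input = grad_out * weight * invstd, so no batch reduction is needed.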
Tensor batch_norm_elementwise_backward_eval(
    const Tensor& grad_out, const Tensor& input,
    const Tensor& invstd, const Tensor& weight) {
  const auto ndim = input.dim();
  DimVector shape(ndim, 1), strides(ndim, 0);
  shape[1] = invstd.sizes()[0];
  strides[1] = invstd.strides()[0];
  auto invstd_nd = invstd.as_strided(shape, strides);
  Tensor grad_input = at::empty(input.sizes(), grad_out.options());

  if (weight.defined()) {
    strides[1] = weight.strides()[0];
    auto weight_nd = weight.as_strided(shape, strides);
    auto iter = TensorIteratorConfig()
        .add_output(grad_input)
        .add_const_input(grad_out)
        .add_const_input(invstd_nd)
        .add_const_input(weight_nd)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_out.scalar_type(),
                                    "batch_norm_eval_backward", [&]{
      using accscalar_t = at::acc_type<scalar_t, true>;
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t gO, accscalar_t invstd, accscalar_t weight)
                 -> scalar_t {
          return gO * weight * invstd;
      });
    });
  } else {
    auto iter = TensorIteratorConfig()
        .add_output(grad_input)
        .add_const_input(grad_out)
        .add_const_input(invstd_nd)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_out.scalar_type(),
                                    "batch_norm_eval_backward", [&]{
      using accscalar_t = at::acc_type<scalar_t, true>;
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t gO, accscalar_t invstd) -> scalar_t {
          return gO * invstd;
      });
    });
  }
  return grad_input;
}


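// Computes the per-channel (biased) mean and variance of self into save_mean and
// save_var, using the dedicated stats kernels where possible and var_mean otherwise.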
void batch_norm_mean_var(const Tensor& self, Tensor& save_mean, Tensor& save_var) {
  // NOTE: Epsilon is only used for InvStd, not Var. The value here is ignored.
  const double dummy_epsilon = 1e-5;
  switch (batch_norm_choose_impl(self)) {
  case Impl::Contiguous: {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
      batch_norm_stats_cuda_template<scalar_t, int32_t, Var>(
          save_mean, save_var, self, dummy_epsilon);
    });
    return;
  }
  case Impl::ChannelsLast: {
    if ((!save_mean.defined() || save_mean.is_contiguous()) &&
        (!save_var.defined() || save_var.is_contiguous())) {
      AT_DISPATCH_FLOATING_TYPES_AND2(
          kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
        batch_norm_stats_channels_last_cuda_template<scalar_t, Var>(
            save_mean, save_var, self, dummy_epsilon);
      });
      return;
    }
    [[fallthrough]];
  }
  case Impl::General: {
    const int64_t ndim = self.dim();
    DimVector reduce_dims(ndim - 1);
    reduce_dims[0] = 0;
    for (int64_t i = 2; i < ndim; ++i) {
      reduce_dims[i - 1] = i;
    }

    // For some reason this isn't an actual operator but it exists anyway...
    at::native::var_mean_out(save_var, save_mean, self, /*dims=*/reduce_dims,
                            /*unbiased=*/false, /*keepdim=*/false);
    return;
  }
  }
}

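// In-place exponential moving average update of the running statistics, applying
// Bessel's correction (N / (N - 1)) to the biased batch variance:
//   running_mean = momentum * mean + (1 - momentum) * running_mean
//   running_var  = momentum * unbiased_var + (1 - momentum) * running_var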
void batch_norm_update_stats(
    const Tensor& save_mean, const Tensor& save_var,
    const Tensor& running_mean, const Tensor& running_var,
    double momentum_, int64_t N) {

  auto iter = TensorIteratorConfig()
      .add_output(running_mean)
      .add_output(running_var)
      .add_input(save_mean)
      .add_input(save_var)
      .add_input(running_mean)
      .add_input(running_var)
      .check_all_same_dtype(false)
      .promote_inputs_to_common_dtype(false)
      .build();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_mean.scalar_type(),
                                  "batch_norm_update_stats_cuda", [&] {
      using acc_t = at::acc_type<scalar_t, true>;
      const auto bessel_correction_factor = static_cast<acc_t>(
          static_cast<double>(N) / static_cast<double>(N - 1));
      const auto momentum = static_cast<acc_t>(momentum_);
      gpu_kernel_multiple_outputs(
          iter, [=] GPU_LAMBDA (acc_t mean, acc_t var, scalar_t running_mean, scalar_t running_var)
               -> thrust::tuple<scalar_t, scalar_t> {
        const auto unbiased_var = var * bessel_correction_factor;
        return thrust::tuple<scalar_t, scalar_t>{
          mean * momentum + (1 - momentum) * running_mean,
          unbiased_var * momentum + (1 - momentum) * running_var,
        };
      });
  });
}

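// Same running-stat update as batch_norm_update_stats, fused with the inversion
// invstd = rsqrt(var + eps); the biased batch variance in save_var is overwritten
// with invstd so the forward elementwise kernel can consume it directly.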
void batch_norm_update_stats_and_invert(
    const Tensor& save_mean, const Tensor& save_var,
    const Tensor& running_mean, const Tensor& running_var,
    double momentum_, double epsilon, int64_t N) {

  auto iter = TensorIteratorConfig()
      .add_output(running_mean)
      .add_output(running_var)
      .add_output(save_var)
      .add_const_input(save_mean)
      .add_input(save_var)
      .add_input(running_mean)
      .add_input(running_var)
      .check_all_same_dtype(false)
      .promote_inputs_to_common_dtype(false)
      .build();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_mean.scalar_type(),
                                  "batch_norm_update_stats_cuda", [&] {
      using acc_t = at::acc_type<scalar_t, true>;
      const auto bessel_correction_factor = static_cast<acc_t>(
          static_cast<double>(N) / static_cast<double>(N - 1));
      const auto eps = static_cast<acc_t>(epsilon);
      const auto momentum = static_cast<acc_t>(momentum_);
      gpu_kernel_multiple_outputs(
          iter, [=] GPU_LAMBDA (acc_t mean, acc_t var, scalar_t running_mean, scalar_t running_var)
               -> thrust::tuple<scalar_t, scalar_t, acc_t> {
        const auto unbiased_var = var * bessel_correction_factor;
        return thrust::tuple<scalar_t, scalar_t, acc_t>{
          mean * momentum + (1 - momentum) * running_mean,
          unbiased_var * momentum + (1 - momentum) * running_var,
          c10::cuda::compat::rsqrt(var + eps)
        };
      });
  });
}

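// Elementwise out_invstd = 1 / sqrt(running_var + epsilon), computed in the
// accumulate type.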
void batch_norm_calc_invstd(const Tensor& out_invstd, const Tensor& running_var, double epsilon) {
  auto iter = TensorIteratorConfig()
      .add_output(out_invstd)
      .add_input(running_var)
      .check_all_same_dtype(false)
      .build();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_var.scalar_type(),
                                  "batch_norm_invert_std_cuda", [&] {
    using acc_t = at::acc_type<scalar_t, true>;
    auto eps = static_cast<acc_t>(epsilon);
    gpu_kernel(iter, [eps] GPU_LAMBDA (scalar_t var) -> acc_t {
      return c10::cuda::compat::rsqrt(var + eps);
    });
  });
}
}

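// Forward entry point with user-provided output buffers: in training mode it
// computes the batch statistics (and updates the running statistics if given);
// in eval mode it copies the running statistics. Either way save_invstd ends up
// holding 1 / sqrt(var + eps) before the elementwise transform is applied.
//
// Hedged usage sketch (illustrative, not part of this file): for CUDA tensors this
// path is normally reached through the dispatcher, e.g.
//   auto y = at::batch_norm(input, weight, bias, running_mean, running_var,
//                           /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5,
//                           /*cudnn_enabled=*/false);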
std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cuda_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
  const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined());
  const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined());
  TORCH_CHECK(has_running_mean == has_running_var);

  if (train) {
    batch_norm_mean_var(self, save_mean, save_invstd);
    if (has_running_mean) {
      const int64_t N = self.numel() / save_mean.numel();
      batch_norm_update_stats_and_invert(
          save_mean, save_invstd, *running_mean_opt, *running_var_opt,
          momentum, epsilon, N);
    } else {
      batch_norm_calc_invstd(save_invstd, save_invstd, epsilon);
    }
  } else {
    TORCH_CHECK(has_running_mean);
    at::native::resize_output(save_mean, running_mean_opt->sizes());
    save_mean.copy_(*running_mean_opt, /*non_blocking=*/true);
    batch_norm_calc_invstd(save_invstd, running_var_opt.value(), epsilon);
  }

  batch_norm_elementwise(output, self, weight_opt, bias_opt, save_mean, save_invstd);
  return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_invstd);
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon) {
  auto output = at::empty_like(self);
  int64_t n_input = self.size(1);
  auto options = self.options().dtype(
      at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
  auto save_mean = at::empty({n_input}, options);
  auto save_invstd = at::empty({n_input}, options);

  at::native::batch_norm_cuda_out(
      self,
      weight_opt,
      bias_opt,
      running_mean_opt,
      running_var_opt,
      train,
      momentum,
      epsilon,
      output,
      save_mean,
      save_invstd);
  return std::make_tuple(output, save_mean, save_invstd);
}

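// Training-only forward that always updates the running statistics. A backend is
// chosen via _select_batch_norm_backend (cuDNN, MIOpen, or the native kernels
// above); the extra `reserve` output is non-empty only for the cuDNN path.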
std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  Tensor output, save_mean, save_var, reserve;

  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
  if (backend == BatchNormBackend::Cudnn) {
    return at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
  }
  if (backend == BatchNormBackend::Miopen) {
    reserve = at::empty({0}, input.options().dtype(kByte));
    std::tie(output, save_mean, save_var) =
        at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
  } else {
    reserve = at::empty({0}, input.options().dtype(kByte));
    std::tie(output, save_mean, save_var) =
        batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
  }
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});

  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
  if (backend == BatchNormBackend::Cudnn) {
    std::tie(out, save_mean, save_var, reserve) =
        at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
  } else if (backend == BatchNormBackend::Miopen) {
    std::tie(out, save_mean, save_var) =
        at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
  } else {
    std::tie(out, save_mean, save_var) =
      batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
  }
  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
  return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_stats_cuda(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, bool train, double momentum, double epsilon) {
  return batch_norm_cuda(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon);
}

std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_cuda_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
  return batch_norm_cuda_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_invstd);
}

std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
  return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
}

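// Backward counterpart of _batch_norm_with_update: dispatches to the cuDNN, MIOpen,
// or native backward using the same backend selection as the forward pass.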
std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
  const Tensor& dummy_bias = at::empty(1);
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
  const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});

  BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);

  if (backend == BatchNormBackend::Cudnn) {
    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
  } else if (backend == BatchNormBackend::Miopen) {
    return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
  } else {
    return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
  }
}

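// Native batch norm backward. When grad_input is requested, 32-bit indexing applies,
// and the channels-last kernels are not needed, a single fused reduction +
// elementwise kernel is used; otherwise the work is split into
// batch_norm_backward_reduce_cuda followed by an elementwise pass. grad_input_mask
// selects which of (grad_input, grad_weight, grad_bias) to compute.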
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
  c10::MaybeOwned<Tensor> save_mean = at::borrow_from_optional_tensor(save_mean_opt);
  c10::MaybeOwned<Tensor> save_invstd = at::borrow_from_optional_tensor(save_invstd_opt);
  c10::MaybeOwned<Tensor> running_mean = at::borrow_from_optional_tensor(running_mean_opt);
  c10::MaybeOwned<Tensor> running_var = at::borrow_from_optional_tensor(running_var_opt);

  const bool needs_reduction = train || grad_input_mask[1] || grad_input_mask[2];

  // Fused reduction & elementwise kernel
  if (needs_reduction && grad_input_mask[0] &&
      !batch_norm_use_channels_last_kernels(input) &&
      cuda::detail::canUse32BitIndexMath(input) &&
      cuda::detail::canUse32BitIndexMath(grad_out)) {
    return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                           "batch_norm_backward_cuda", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const bool mixed_type = is_mixed_type(input, *weight, *running_mean, *running_var);
      if (mixed_type) {
          return batch_norm_backward_cuda_template<scalar_t, accscalar_t, int32_t>(
              grad_out, input, *weight, *running_mean, *running_var,
              *save_mean, *save_invstd, train, epsilon, grad_input_mask);
      } else {
          return batch_norm_backward_cuda_template<scalar_t, scalar_t, int32_t>(
              grad_out, input, *weight, *running_mean, *running_var,
              *save_mean, *save_invstd, train, epsilon, grad_input_mask);
      }
    });
  }

  // NOTE: native_batch_norm always returns save_mean and save_invstd to be reused in backward.
  // However, this is also called from cudnn_batch_norm in eval mode, which doesn't give
  // save_mean and save_invstd, so they need to be recalculated here.
  const auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true);
  Tensor mean;
  TORCH_INTERNAL_ASSERT(save_mean->defined(), "save_mean should always be defined\n");
  if (save_mean->numel() != 0) {
    mean = *save_mean;
  } else if (needs_reduction) {
    TORCH_CHECK(!train && running_mean->defined());
    mean = (running_mean->scalar_type() == acc_type) ?
        *running_mean : running_mean->to(acc_type);
  }

  Tensor invstd;
  TORCH_INTERNAL_ASSERT(save_invstd->defined(), "save_invstd should always be defined\n");
  if (save_invstd->numel() != 0) {
    invstd = *save_invstd;
  } else {
    TORCH_CHECK(!train && running_var->defined());
    auto n_channels = input.sizes()[1];
    invstd = at::empty({n_channels}, input.options().dtype(acc_type));
    batch_norm_calc_invstd(invstd, *running_var, epsilon);
  }

  Tensor sum_dy, sum_dy_xmu, grad_weight, grad_bias;
  if (needs_reduction) {
    std::tie(sum_dy, sum_dy_xmu, grad_weight, grad_bias) =
        batch_norm_backward_reduce_cuda(
            grad_out, input, mean, invstd, *weight,
            grad_input_mask[0], grad_input_mask[1], grad_input_mask[2]);
  }

  Tensor grad_input;
  if (grad_input_mask[0]) {
    if (train) {
      // NOTE: sum_dy and sum_dy_xmu are defined, as train implies needs_reduction
      grad_input = batch_norm_elementwise_backward_train(
          grad_out, input, mean, invstd, *weight, sum_dy, sum_dy_xmu);
    } else {
      grad_input = batch_norm_elementwise_backward_eval(
          grad_out, input, invstd, *weight);
    }
  }

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

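// Returns the per-channel (mean, invstd) of self with epsilon already folded into
// invstd; used, for example, by the SyncBatchNorm path.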
std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsilon) {
  auto options = self.options().dtype(
      at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
  auto n_channels = self.size(1);
  auto save_mean = at::empty({n_channels}, options);
  auto save_invstd = at::empty({n_channels}, options);

  bool use_channels_last_kernel = batch_norm_use_channels_last_kernels(self);
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
                                  self.scalar_type(), "batch_norm_stats_cuda", [&] {
    if (cuda::detail::canUse32BitIndexMath(self)) {
      if (use_channels_last_kernel) {
        batch_norm_stats_channels_last_cuda_template<scalar_t, InvStd>(
            save_mean, save_invstd, self, epsilon);
      } else {
        batch_norm_stats_cuda_template<scalar_t, int32_t, InvStd>(
            save_mean, save_invstd, self, epsilon);
      }
    } else {
      batch_norm_stats_cuda_template<scalar_t, int64_t, InvStd>(
          save_mean, save_invstd, self, epsilon);
    }
  });
  return std::tuple<Tensor, Tensor>(save_mean, save_invstd);
}

Tensor batch_norm_elemt_cuda(
    const Tensor& self, const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt, const Tensor& mean,
    const Tensor& invstd, double epsilon) {
  auto output = at::empty_like(self);
  // FIXME: Epsilon parameter isn't required, we don't take the reciprocal
  batch_norm_elementwise(output, self, weight_opt, bias_opt, mean, invstd);
  return output;
}

Tensor& batch_norm_elemt_cuda_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
                                  const Tensor& mean, const Tensor& invstd, double epsilon, Tensor& output) {
  // FIXME: Epsilon parameter isn't required, we don't take the reciprocal
  batch_norm_elementwise(output, self, weight_opt, bias_opt, mean, invstd);
  return output;
}

// accepting input(self) here to determine template data types, since running_mean/running_var are optional
std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, double momentum, double epsilon, int64_t count) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  std::vector<int64_t> counts(mean.size(0), count);
  Tensor counts_ = at::from_blob((void*)counts.data(), {(int64_t)counts.size()}, self.options().dtype(at::kLong).device(at::kCPU));
  counts_ = counts_.to(self.device()).to(running_mean.defined() ? running_mean.dtype() : self.dtype());
  return batch_norm_gather_stats_with_counts_cuda(self, mean, invstd, running_mean, running_var, momentum, epsilon, counts_);
}


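// Reduces per-replica (mean, invstd) statistics into global statistics, weighting
// each contribution by `counts`, and updates the running statistics if provided.
// This is the cross-device aggregation step of synchronized batch norm.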
std::tuple<Tensor, Tensor> batch_norm_gather_stats_with_counts_cuda(
    const Tensor& self, const Tensor& mean, const Tensor& invstd, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */, double momentum, double epsilon, const Tensor& counts) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});


  auto scalar_type = running_mean.defined() ? running_mean.scalar_type() : self.scalar_type();
  return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "batch_norm_update_stats_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    if (cuda::detail::canUse32BitIndexMath(self)) {
      return batch_norm_gather_stats_cuda_template<scalar_t, accscalar_t, int32_t>(mean, invstd, running_mean, running_var, momentum, epsilon, counts);
    } else {
      return batch_norm_gather_stats_cuda_template<scalar_t, accscalar_t, int64_t>(mean, invstd, running_mean, running_var, momentum, epsilon, counts);
    }
  });
}

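// Per-channel reductions for the backward pass: returns (sum_dy, sum_dy_xmu,
// grad_weight, grad_bias); input_g gates the two sums while weight_g and bias_g
// gate grad_weight and grad_bias respectively.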
std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda(const Tensor& grad_output, const Tensor& input, const Tensor& mean, const Tensor& invstd, const std::optional<Tensor>& weight_opt, bool input_g, bool weight_g, bool bias_g) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  if (at::cuda::detail::canUse32BitIndexMath(grad_output) &&
      batch_norm_use_channels_last_kernels(grad_output) &&
      batch_norm_use_channels_last_kernels(input) &&
      (!weight.defined() || weight.is_contiguous()) &&
      mean.is_contiguous() && invstd.is_contiguous()){
    return batch_norm_backward_reduce_cuda_channels_last_template(
        grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
  }

  return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(), "batch_norm_backward_reduce", [&] {
    auto mean_st = mean.dtype();
    auto invstd_st = invstd.dtype();
    TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types");
    const bool mixed_type = is_mixed_type(input, weight);
    using accscalar_t = at::acc_type<scalar_t, true>;

    if (cuda::detail::canUse32BitIndexMath(grad_output)) {
      if (mixed_type) {
        return batch_norm_backward_reduce_cuda_template<scalar_t, accscalar_t, int32_t>(grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
      } else {
        return batch_norm_backward_reduce_cuda_template<scalar_t, scalar_t, int32_t>(grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
      }
    } else {
      if (mixed_type) {
        return batch_norm_backward_reduce_cuda_template<scalar_t, accscalar_t, int64_t>(grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
      } else {
        return batch_norm_backward_reduce_cuda_template<scalar_t, scalar_t, int64_t>(grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
      }
    }
  });
}

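// Elementwise part of the distributed backward: combines the per-channel reductions
// into grad_input, normalizing by the total element count accumulated from `count`
// (one entry per replica).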
Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd, const std::optional<Tensor>& weight_opt, const Tensor& sum_dy, const Tensor& sum_dy_xmu, const Tensor& count) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  if (at::cuda::detail::canUse32BitIndexMath(self) &&
      batch_norm_use_channels_last_kernels(self) &&
      batch_norm_use_channels_last_kernels(input))  {
    return batch_norm_backward_elemt_channels_last_cuda_template(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
  }

  return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward_elemt", [&] {
    auto mean_st = mean.dtype();
    auto invstd_st = invstd.dtype();
    TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types");
    bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
    bool is_bfloat16_float = std::is_same<scalar_t, at::BFloat16>::value && mean_st == at::kFloat;
    using accscalar_t = at::acc_type<scalar_t, true>;
    if (cuda::detail::canUse32BitIndexMath(self)) {
      if (is_half_float || is_bfloat16_float) {
        return batch_norm_backward_elemt_cuda_template<scalar_t, accscalar_t, int32_t>(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
      } else {
        return batch_norm_backward_elemt_cuda_template<scalar_t, scalar_t, int32_t>(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
      }
    } else {
      if (is_half_float || is_bfloat16_float) {
        return batch_norm_backward_elemt_cuda_template<scalar_t, accscalar_t, int64_t>(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
      } else {
        return batch_norm_backward_elemt_cuda_template<scalar_t, scalar_t, int64_t>(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
      }
    }
  });
}

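// Computes the batch mean/var of self and, when running statistics are supplied,
// updates them in place via batch_norm_update_stats; returns (save_mean, save_var).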
std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
    const Tensor& self, const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt, double momentum) {
  c10::MaybeOwned<Tensor> running_mean = at::borrow_from_optional_tensor(running_mean_opt);
  c10::MaybeOwned<Tensor> running_var = at::borrow_from_optional_tensor(running_var_opt);

  const int64_t n_input = self.size(1);

  TORCH_CHECK(self.numel() != 0, "input tensor must have at least one element, but got input_sizes = ", self.sizes());
  auto options = self.options().dtype(
      at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
  auto save_mean = at::empty({n_input}, options);
  auto save_var = at::empty({n_input}, options);

  batch_norm_mean_var(self, save_mean, save_var);
  TORCH_CHECK(running_mean->defined() == running_var->defined());
  if (running_mean->defined()) {
    const int64_t N = self.numel() / save_mean.numel();
    batch_norm_update_stats(save_mean, save_var, *running_mean, *running_var, momentum, N);
  }
  return std::tuple<Tensor, Tensor>(save_mean, save_var);
}

} // namespace at::native