#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/Normalization.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Resize.h>
#include <ATen/native/cuda/Normalization.cuh>
#include <c10/cuda/CUDAMathCompat.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/batch_norm_backward_elemt_native.h>
#include <ATen/ops/batch_norm_backward_reduce_native.h>
#include <ATen/ops/batch_norm_elemt_native.h>
#include <ATen/ops/batch_norm_gather_stats_native.h>
#include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
#include <ATen/ops/batch_norm_stats_native.h>
#include <ATen/ops/batch_norm_update_stats_native.h>
#include <ATen/ops/cudnn_batch_norm.h>
#include <ATen/ops/cudnn_batch_norm_backward.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/miopen_batch_norm.h>
#include <ATen/ops/miopen_batch_norm_backward.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at::native {

namespace {

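// Returns the scalar type of the first defined tensor among the arguments,
// or ScalarType::Undefined if none of them are defined.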
ScalarType first_type() {
  return ScalarType::Undefined;
}

template <typename... Args>
ScalarType first_type(const Tensor& arg, const Args&... parameters) {
  return arg.defined() ? arg.scalar_type() : first_type(parameters...);
}

// A transform is mixed type if the parameters are higher precision than the input
template <typename... Args>
bool is_mixed_type(const Tensor& input, const Args&... parameters) {
  const auto parameter_type = first_type(parameters...);
  return ((parameter_type != ScalarType::Undefined) &&
          (parameter_type != input.scalar_type()));
}

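// True when the channels-last kernels can handle self: either the tensor is
// channels-last (2d or 3d) contiguous, or it is contiguous with a channel
// stride of 1.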
inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
  return (
    self.is_contiguous(at::MemoryFormat::ChannelsLast) ||
    self.is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
    (self.is_contiguous() && self.strides()[1] == 1)
  );
}

enum class Impl {
  Contiguous,
  ChannelsLast,
  General,
};

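// Chooses the kernel implementation for a single tensor: General when 32-bit
// indexing is not possible, otherwise Contiguous or ChannelsLast depending on
// the memory layout.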
inline Impl batch_norm_choose_impl(const Tensor& self) {
  if (!at::cuda::detail::canUse32BitIndexMath(self)) {
    return Impl::General;
  }

  if (self.is_contiguous()) {
    return self.strides()[1] == 1 ? Impl::ChannelsLast : Impl::Contiguous;
  }

  if (self.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    return Impl::ChannelsLast;
  }

  return Impl::General;
}

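// Two-tensor variant: both tensors must agree on the specialized
// implementation, otherwise fall back to General.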
inline Impl batch_norm_choose_impl(const Tensor& in1, const Tensor& in2) {
  auto imp1 = batch_norm_choose_impl(in1);
  if (imp1 == Impl::General) {
    return imp1;
  }
  auto imp2 = batch_norm_choose_impl(in2);
  return imp1 == imp2 ? imp1 : Impl::General;
}

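// Applies the elementwise batch norm transform
//   out = (self - mean) * invstd * weight + bias
// dispatching between the contiguous, channels-last, and TensorIterator-based
// general kernels.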
void batch_norm_elementwise(
    const Tensor& out, const Tensor& self, const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt, const Tensor& mean_, const Tensor& invstd_) {
  switch (batch_norm_choose_impl(self)) {
  case Impl::Contiguous: {
    c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
    c10::MaybeOwned<Tensor> bias = at::borrow_from_optional_tensor(bias_opt);
    resize_output(out, self.sizes());
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(),
                                    "batch_norm_elementwise_cuda", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const bool mixed_type = is_mixed_type(self, *weight, *bias);
      if (mixed_type) {
        batch_norm_elemt_cuda_template<scalar_t, accscalar_t, int32_t>(
            out, self, *weight, *bias, mean_, invstd_);
      } else {
        batch_norm_elemt_cuda_template<scalar_t, scalar_t, int32_t>(
            out, self, *weight, *bias, mean_, invstd_);
      }
    });
    return;
  }
  case Impl::ChannelsLast: {
    auto weight = at::borrow_from_optional_tensor(weight_opt);
    auto bias = at::borrow_from_optional_tensor(bias_opt);

    if (resize_output_check(out, self.sizes())) {
      resize_impl_cuda_(out.unsafeGetTensorImpl(), self.sizes(), self.strides());
    }
    if ((out.strides() == self.strides()) &&
        (!weight->defined() || weight->is_contiguous()) &&
        (!bias->defined() || bias->is_contiguous()) &&
        (!mean_.defined() || mean_.is_contiguous()) &&
        (!invstd_.defined() || invstd_.is_contiguous())) {
      batch_norm_elemt_channels_last_cuda_template(
          out, self, *weight, *bias, mean_, invstd_);
      return;
    }
    [[fallthrough]];
  }
  case Impl::General: {
    const int64_t ndim = self.dim();
    DimVector sizes(ndim, 1), strides(ndim, 0);
    // Helper to convert 1d tensors to an nd tensor that broadcasts with input
    // All elements go into the channel dimension
    auto as_nd = [&](const Tensor& t) {
      TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1);
      sizes[1] = t.sizes()[0];
      strides[1] = t.strides()[0];
      return t.as_strided(sizes, strides);
    };

    auto weight = weight_opt.has_value() && weight_opt->defined() ?
        as_nd(*weight_opt) : at::scalar_tensor(1, mean_.options());
    auto bias = bias_opt.has_value() && bias_opt->defined() ?
        as_nd(*bias_opt) : at::scalar_tensor(0, mean_.options());
    auto mean = as_nd(mean_);
    auto invstd = as_nd(invstd_);

    auto iter = TensorIteratorConfig()
        .add_output(out)
        .add_input(self)
        .add_input(weight)
        .add_input(bias)
        .add_input(mean)
        .add_input(invstd)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(),
                                    "batch_norm_elementwise_cuda", [&] {
      using acc_t = at::acc_type<scalar_t, true>;
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t input, acc_t weight, acc_t bias,
                                      acc_t mean, acc_t invstd) -> scalar_t {
        return (input - mean) * weight * invstd + bias;
      });
    });
    return;
  }
  }
}

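// Training-mode backward of the elementwise transform. In the general path the
// kernel computes
//   grad_input = (grad_out - sum_dy / N - (input - mean) * invstd^2 * sum_dy_xmu / N)
//                * weight * invstd
// where N is the number of elements per channel.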
Tensor batch_norm_elementwise_backward_train(
    const Tensor& grad_out, const Tensor& input, const Tensor& mean, const Tensor& invstd,
    const Tensor& weight, const Tensor& sum_dy, const Tensor& sum_dy_xmu) {
  switch (batch_norm_choose_impl(input, grad_out)) {
  case Impl::Contiguous: {
    return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                           "batch_norm_backward_elemt", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const bool mixed_type = is_mixed_type(input, weight);
      if (mixed_type) {
        return batch_norm_backward_elemt_cuda_template<scalar_t, accscalar_t, int32_t>(
            grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
      } else {
        return batch_norm_backward_elemt_cuda_template<scalar_t, scalar_t, int32_t>(
            grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
      }
    });
  }
  case Impl::ChannelsLast: {
    if ((!weight.defined() || weight.is_contiguous()) &&
        mean.is_contiguous() && invstd.is_contiguous()) {
      return batch_norm_backward_elemt_channels_last_cuda_template(
          grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
    }
    [[fallthrough]];
  }
  case Impl::General: {
    const auto ndim = input.dim();
    DimVector sizes(ndim, 1), strides(ndim, 0);
    auto as_nd = [&](const Tensor& t) {
      TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1);
      sizes[1] = t.sizes()[0];
      strides[1] = t.strides()[0];
      return t.as_strided(sizes, strides);
    };
    auto invstd_nd = as_nd(invstd);
    auto mean_nd = as_nd(mean);
    auto sum_dy_nd = as_nd(sum_dy);
    auto sum_dy_xmu_nd = as_nd(sum_dy_xmu);
    auto weight_nd = weight.defined() ? as_nd(weight) :
        at::scalar_tensor(1.0, input.options().dtype(mean.scalar_type()));

    Tensor grad_input = at::empty(input.sizes(), grad_out.options().memory_format(input.suggest_memory_format()));
    auto iter = TensorIteratorConfig()
        .add_output(grad_input)
        .add_input(grad_out)
        .add_input(input)
        .add_input(weight_nd)
        .add_input(mean_nd)
        .add_input(invstd_nd)
        .add_input(sum_dy_xmu_nd)
        .add_input(sum_dy_nd)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_out.scalar_type(),
                                    "batch_norm_eval_backward", [&]{
      using accscalar_t = at::acc_type<scalar_t, true>;
      auto norm_fct = static_cast<accscalar_t>(1.0 / (input.numel() / input.size(1)));
      gpu_kernel(iter, [norm_fct] GPU_LAMBDA (scalar_t gO, scalar_t input, accscalar_t weight,
                                              accscalar_t mean, accscalar_t invstd,
                                              accscalar_t xmu, accscalar_t dy) -> scalar_t {
        auto factor_1_c = invstd * invstd * xmu * norm_fct;
        auto factor_2_c = weight * invstd;
        auto m_dy_c = dy * norm_fct;
        return (gO - m_dy_c - (input - mean) * factor_1_c) * factor_2_c;
      });
    });
    return grad_input;
  }
  }
  TORCH_INTERNAL_ASSERT(false);
}

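// Eval-mode backward of the elementwise transform: the statistics are
// constants, so grad_input = grad_out * weight * invstd (or grad_out * invstd
// when weight is undefined).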
Tensor batch_norm_elementwise_backward_eval(
    const Tensor& grad_out, const Tensor& input,
    const Tensor& invstd, const Tensor& weight) {
  const auto ndim = input.dim();
  DimVector shape(ndim, 1), strides(ndim, 0);
  shape[1] = invstd.sizes()[0];
  strides[1] = invstd.strides()[0];
  auto invstd_nd = invstd.as_strided(shape, strides);
  Tensor grad_input = at::empty(input.sizes(), grad_out.options());

  if (weight.defined()) {
    strides[1] = weight.strides()[0];
    auto weight_nd = weight.as_strided(shape, strides);
    auto iter = TensorIteratorConfig()
        .add_output(grad_input)
        .add_const_input(grad_out)
        .add_const_input(invstd_nd)
        .add_const_input(weight_nd)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_out.scalar_type(),
                                    "batch_norm_eval_backward", [&]{
      using accscalar_t = at::acc_type<scalar_t, true>;
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t gO, accscalar_t invstd, accscalar_t weight)
                 -> scalar_t {
        return gO * weight * invstd;
      });
    });
  } else {
    auto iter = TensorIteratorConfig()
        .add_output(grad_input)
        .add_const_input(grad_out)
        .add_const_input(invstd_nd)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_out.scalar_type(),
                                    "batch_norm_eval_backward", [&]{
      using accscalar_t = at::acc_type<scalar_t, true>;
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t gO, accscalar_t invstd) -> scalar_t {
        return gO * invstd;
      });
    });
  }
  return grad_input;
}

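// Computes the per-channel mean and biased variance of self into save_mean and
// save_var, picking the best kernel for the layout.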
void batch_norm_mean_var(const Tensor& self, Tensor& save_mean, Tensor& save_var) {
  // NOTE: Epsilon is only used for InvStd, not Var. The value here is ignored.
  const double dummy_epsilon = 1e-5;
  switch (batch_norm_choose_impl(self)) {
  case Impl::Contiguous: {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
          batch_norm_stats_cuda_template<scalar_t, int32_t, Var>(
              save_mean, save_var, self, dummy_epsilon);
        });
    return;
  }
  case Impl::ChannelsLast: {
    if ((!save_mean.defined() || save_mean.is_contiguous()) &&
        (!save_var.defined() || save_var.is_contiguous())) {
      AT_DISPATCH_FLOATING_TYPES_AND2(
          kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
            batch_norm_stats_channels_last_cuda_template<scalar_t, Var>(
                save_mean, save_var, self, dummy_epsilon);
          });
      return;
    }
    [[fallthrough]];
  }
  case Impl::General: {
    const int64_t ndim = self.dim();
    DimVector reduce_dims(ndim - 1);
    reduce_dims[0] = 0;
    for (int64_t i = 2; i < ndim; ++i) {
      reduce_dims[i - 1] = i;
    }

    // For some reason this isn't an actual operator but it exists anyway...
    at::native::var_mean_out(save_var, save_mean, self, /*dims=*/reduce_dims,
                             /*unbiased=*/false, /*keepdim=*/false);
    return;
  }
  }
}

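// Updates the running statistics in place:
//   running = momentum * batch_stat + (1 - momentum) * running
// applying Bessel's correction N / (N - 1) to the batch variance, where N is
// the number of elements per channel.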
void batch_norm_update_stats(
    const Tensor& save_mean, const Tensor& save_var,
    const Tensor& running_mean, const Tensor& running_var,
    double momentum_, int64_t N) {

  auto iter = TensorIteratorConfig()
      .add_output(running_mean)
      .add_output(running_var)
      .add_input(save_mean)
      .add_input(save_var)
      .add_input(running_mean)
      .add_input(running_var)
      .check_all_same_dtype(false)
      .promote_inputs_to_common_dtype(false)
      .build();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_mean.scalar_type(),
                                  "batch_norm_update_stats_cuda", [&] {
    using acc_t = at::acc_type<scalar_t, true>;
    const auto bessel_correction_factor = static_cast<acc_t>(
        static_cast<double>(N) / static_cast<double>(N - 1));
    const auto momentum = static_cast<acc_t>(momentum_);
    gpu_kernel_multiple_outputs(
        iter, [=] GPU_LAMBDA (acc_t mean, acc_t var, scalar_t running_mean, scalar_t running_var)
            -> thrust::tuple<scalar_t, scalar_t> {
          const auto unbiased_var = var * bessel_correction_factor;
          return thrust::tuple<scalar_t, scalar_t>{
            mean * momentum + (1 - momentum) * running_mean,
            unbiased_var * momentum + (1 - momentum) * running_var,
          };
        });
  });
}

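// Same as batch_norm_update_stats, but additionally overwrites save_var with
// the inverse standard deviation 1 / sqrt(var + epsilon) for use by the
// forward kernel.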
void batch_norm_update_stats_and_invert(
    const Tensor& save_mean, const Tensor& save_var,
    const Tensor& running_mean, const Tensor& running_var,
    double momentum_, double epsilon, int64_t N) {

  auto iter = TensorIteratorConfig()
      .add_output(running_mean)
      .add_output(running_var)
      .add_output(save_var)
      .add_const_input(save_mean)
      .add_input(save_var)
      .add_input(running_mean)
      .add_input(running_var)
      .check_all_same_dtype(false)
      .promote_inputs_to_common_dtype(false)
      .build();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_mean.scalar_type(),
                                  "batch_norm_update_stats_cuda", [&] {
    using acc_t = at::acc_type<scalar_t, true>;
    const auto bessel_correction_factor = static_cast<acc_t>(
        static_cast<double>(N) / static_cast<double>(N - 1));
    const auto eps = static_cast<acc_t>(epsilon);
    const auto momentum = static_cast<acc_t>(momentum_);
    gpu_kernel_multiple_outputs(
        iter, [=] GPU_LAMBDA (acc_t mean, acc_t var, scalar_t running_mean, scalar_t running_var)
            -> thrust::tuple<scalar_t, scalar_t, acc_t> {
          const auto unbiased_var = var * bessel_correction_factor;
          return thrust::tuple<scalar_t, scalar_t, acc_t>{
            mean * momentum + (1 - momentum) * running_mean,
            unbiased_var * momentum + (1 - momentum) * running_var,
            c10::cuda::compat::rsqrt(var + eps)
          };
        });
  });
}

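// Computes out_invstd = 1 / sqrt(running_var + epsilon) elementwise.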
void batch_norm_calc_invstd(const Tensor& out_invstd, const Tensor& running_var, double epsilon) {
  auto iter = TensorIteratorConfig()
      .add_output(out_invstd)
      .add_input(running_var)
      .check_all_same_dtype(false)
      .build();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_var.scalar_type(),
                                  "batch_norm_invert_std_cuda", [&] {
    using acc_t = at::acc_type<scalar_t, true>;
    auto eps = static_cast<acc_t>(epsilon);
    gpu_kernel(iter, [eps] GPU_LAMBDA (scalar_t var) -> acc_t {
      return c10::cuda::compat::rsqrt(var + eps);
    });
  });
}
} // anonymous namespace

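// Native CUDA batch norm (out variant). In training mode the batch statistics
// are computed (and the running statistics updated when present); in eval mode
// the running statistics are used directly. The elementwise transform is then
// applied to produce the output.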
std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cuda_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
  const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined());
  const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined());
  TORCH_CHECK(has_running_mean == has_running_var);

  if (train) {
    batch_norm_mean_var(self, save_mean, save_invstd);
    if (has_running_mean) {
      const int64_t N = self.numel() / save_mean.numel();
      batch_norm_update_stats_and_invert(
          save_mean, save_invstd, *running_mean_opt, *running_var_opt,
          momentum, epsilon, N);
    } else {
      batch_norm_calc_invstd(save_invstd, save_invstd, epsilon);
    }
  } else {
    TORCH_CHECK(has_running_mean);
    at::native::resize_output(save_mean, running_mean_opt->sizes());
    save_mean.copy_(*running_mean_opt, /*non_blocking=*/true);
    batch_norm_calc_invstd(save_invstd, running_var_opt.value(), epsilon);
  }

  batch_norm_elementwise(output, self, weight_opt, bias_opt, save_mean, save_invstd);
  return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_invstd);
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon) {
  auto output = at::empty_like(self);
  int64_t n_input = self.size(1);
  auto options = self.options().dtype(
      at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
  auto save_mean = at::empty({n_input}, options);
  auto save_invstd = at::empty({n_input}, options);

  at::native::batch_norm_cuda_out(
      self,
      weight_opt,
      bias_opt,
      running_mean_opt,
      running_var_opt,
      train,
      momentum,
      epsilon,
      output,
      save_mean,
      save_invstd);
  return std::make_tuple(output, save_mean, save_invstd);
}

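// Training-mode batch norm that always updates the running statistics. The
// backend selector may route to cuDNN or MIOpen; only the cuDNN path fills the
// reserve tensor, the other paths return an empty one.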
std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  Tensor output, save_mean, save_var, reserve;

  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
  if (backend == BatchNormBackend::Cudnn) {
    return at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
  }
  if (backend == BatchNormBackend::Miopen) {
    reserve = at::empty({0}, input.options().dtype(kByte));
    std::tie(output, save_mean, save_var) =
        at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
  } else {
    reserve = at::empty({0}, input.options().dtype(kByte));
    std::tie(output, save_mean, save_var) =
        batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
  }
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});

  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
  if (backend == BatchNormBackend::Cudnn) {
    std::tie(out, save_mean, save_var, reserve) =
        at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
  } else if (backend == BatchNormBackend::Miopen) {
    std::tie(out, save_mean, save_var) =
        at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
  } else {
    std::tie(out, save_mean, save_var) =
        batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
  }
  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
  return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_stats_cuda(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, bool train, double momentum, double epsilon) {
  return batch_norm_cuda(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon);
}

std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_cuda_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
  return batch_norm_cuda_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_invstd);
}

std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
  return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
}

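// Backward entry point matching _batch_norm_with_update: routes to the cuDNN,
// MIOpen, or native backward depending on the backend the forward would have
// selected.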
std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
  const Tensor& dummy_bias = at::empty(1);
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
  const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});

  BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);

  if (backend == BatchNormBackend::Cudnn) {
    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
  } else if (backend == BatchNormBackend::Miopen) {
    return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
  } else {
    return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
  }
}

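// Native CUDA backward. Uses the fused reduction + elementwise kernel when the
// layout and indexing allow it; otherwise it recovers mean/invstd (recomputing
// them from the running stats in eval mode), runs the per-channel reductions,
// and applies the train- or eval-mode elementwise backward.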
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
  c10::MaybeOwned<Tensor> save_mean = at::borrow_from_optional_tensor(save_mean_opt);
  c10::MaybeOwned<Tensor> save_invstd = at::borrow_from_optional_tensor(save_invstd_opt);
  c10::MaybeOwned<Tensor> running_mean = at::borrow_from_optional_tensor(running_mean_opt);
  c10::MaybeOwned<Tensor> running_var = at::borrow_from_optional_tensor(running_var_opt);

  const bool needs_reduction = train || grad_input_mask[1] || grad_input_mask[2];

  // Fused reduction & elementwise kernel
  if (needs_reduction && grad_input_mask[0] &&
      !batch_norm_use_channels_last_kernels(input) &&
      cuda::detail::canUse32BitIndexMath(input) &&
      cuda::detail::canUse32BitIndexMath(grad_out)) {
    return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                           "batch_norm_backward_cuda", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const bool mixed_type = is_mixed_type(input, *weight, *running_mean, *running_var);
      if (mixed_type) {
        return batch_norm_backward_cuda_template<scalar_t, accscalar_t, int32_t>(
            grad_out, input, *weight, *running_mean, *running_var,
            *save_mean, *save_invstd, train, epsilon, grad_input_mask);
      } else {
        return batch_norm_backward_cuda_template<scalar_t, scalar_t, int32_t>(
            grad_out, input, *weight, *running_mean, *running_var,
            *save_mean, *save_invstd, train, epsilon, grad_input_mask);
      }
    });
  }

  // NOTE: native_batch_norm always returns save_mean and save_invstd to be reused in backward.
  // However, this is also called from cudnn_batch_norm in eval mode, which doesn't provide
  // save_mean and save_invstd, so they need to be recalculated here.
  const auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true);
  Tensor mean;
  TORCH_INTERNAL_ASSERT(save_mean->defined(), "save_mean should always be defined\n");
  if (save_mean->numel() != 0) {
    mean = *save_mean;
  } else if (needs_reduction) {
    TORCH_CHECK(!train && running_mean->defined());
    mean = (running_mean->scalar_type() == acc_type) ?
        *running_mean : running_mean->to(acc_type);
  }

  Tensor invstd;
  TORCH_INTERNAL_ASSERT(save_invstd->defined(), "save_invstd should always be defined\n");
  if (save_invstd->numel() != 0) {
    invstd = *save_invstd;
  } else {
    TORCH_CHECK(!train && running_var->defined());
    auto n_channels = input.sizes()[1];
    invstd = at::empty({n_channels}, input.options().dtype(acc_type));
    batch_norm_calc_invstd(invstd, *running_var, epsilon);
  }

  Tensor sum_dy, sum_dy_xmu, grad_weight, grad_bias;
  if (needs_reduction) {
    std::tie(sum_dy, sum_dy_xmu, grad_weight, grad_bias) =
        batch_norm_backward_reduce_cuda(
            grad_out, input, mean, invstd, *weight,
            grad_input_mask[0], grad_input_mask[1], grad_input_mask[2]);
  }

  Tensor grad_input;
  if (grad_input_mask[0]) {
    if (train) {
      // NOTE: sum_dy and sum_dy_xmu are defined, as train implies needs_reduction
      grad_input = batch_norm_elementwise_backward_train(
          grad_out, input, mean, invstd, *weight, sum_dy, sum_dy_xmu);
    } else {
      grad_input = batch_norm_elementwise_backward_eval(
          grad_out, input, invstd, *weight);
    }
  }

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

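// Computes per-channel mean and inverse standard deviation (1 / sqrt(var + epsilon))
// in the accumulation dtype, choosing 32- or 64-bit indexing and the
// channels-last kernel as appropriate.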
std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsilon) {
  auto options = self.options().dtype(
      at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
  auto n_channels = self.size(1);
  auto save_mean = at::empty({n_channels}, options);
  auto save_invstd = at::empty({n_channels}, options);

  bool use_channels_last_kernel = batch_norm_use_channels_last_kernels(self);
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
                                  self.scalar_type(), "batch_norm_stats_cuda", [&] {
    if (cuda::detail::canUse32BitIndexMath(self)) {
      if (use_channels_last_kernel) {
        batch_norm_stats_channels_last_cuda_template<scalar_t, InvStd>(
            save_mean, save_invstd, self, epsilon);
      } else {
        batch_norm_stats_cuda_template<scalar_t, int32_t, InvStd>(
            save_mean, save_invstd, self, epsilon);
      }
    } else {
      batch_norm_stats_cuda_template<scalar_t, int64_t, InvStd>(
          save_mean, save_invstd, self, epsilon);
    }
  });
  return std::tuple<Tensor, Tensor>(save_mean, save_invstd);
}

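// Applies precomputed mean/invstd (and optional weight/bias) elementwise to
// self, returning a freshly allocated output.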
Tensor batch_norm_elemt_cuda(
    const Tensor& self, const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt, const Tensor& mean,
    const Tensor& invstd, double epsilon) {
  auto output = at::empty_like(self);
  // FIXME: Epsilon parameter isn't required, we don't take the reciprocal
  batch_norm_elementwise(output, self, weight_opt, bias_opt, mean, invstd);
  return output;
}

Tensor& batch_norm_elemt_cuda_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
                                  const Tensor& mean, const Tensor& invstd, double epsilon, Tensor& output) {
  // FIXME: Epsilon parameter isn't required, we don't take the reciprocal
  batch_norm_elementwise(output, self, weight_opt, bias_opt, mean, invstd);
  return output;
}

// Accepting input (self) here to determine the template data types, since
// running_mean/running_var are optional.
std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, double momentum, double epsilon, int64_t count) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  std::vector<int64_t> counts(mean.size(0), count);
  Tensor counts_ = at::from_blob((void*)counts.data(), {(int64_t)counts.size()}, self.options().dtype(at::kLong).device(at::kCPU));
  counts_ = counts_.to(self.device()).to(running_mean.defined() ? running_mean.dtype() : self.dtype());
  return batch_norm_gather_stats_with_counts_cuda(self, mean, invstd, running_mean, running_var, momentum, epsilon, counts_);
}

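// Reduces per-replica statistics (mean/invstd, weighted by the element counts)
// into global statistics, optionally updating running_mean/running_var in the
// process.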
std::tuple<Tensor, Tensor> batch_norm_gather_stats_with_counts_cuda(
    const Tensor& self, const Tensor& mean, const Tensor& invstd, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */, double momentum, double epsilon, const Tensor& counts) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  auto scalar_type = running_mean.defined() ? running_mean.scalar_type() : self.scalar_type();
  return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "batch_norm_update_stats_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    if (cuda::detail::canUse32BitIndexMath(self)) {
      return batch_norm_gather_stats_cuda_template<scalar_t, accscalar_t, int32_t>(mean, invstd, running_mean, running_var, momentum, epsilon, counts);
    } else {
      return batch_norm_gather_stats_cuda_template<scalar_t, accscalar_t, int64_t>(mean, invstd, running_mean, running_var, momentum, epsilon, counts);
    }
  });
}

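// Reduction half of the backward: per-channel sum_dy and sum_dy_xmu, plus
// grad_weight / grad_bias when the corresponding *_g flags are set.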
std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda(const Tensor& grad_output, const Tensor& input, const Tensor& mean, const Tensor& invstd, const std::optional<Tensor>& weight_opt, bool input_g, bool weight_g, bool bias_g) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  if (at::cuda::detail::canUse32BitIndexMath(grad_output) &&
      batch_norm_use_channels_last_kernels(grad_output) &&
      batch_norm_use_channels_last_kernels(input) &&
      (!weight.defined() || weight.is_contiguous()) &&
      mean.is_contiguous() && invstd.is_contiguous()) {
    return batch_norm_backward_reduce_cuda_channels_last_template(
        grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
  }

  return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(), "batch_norm_backward_reduce", [&] {
    auto mean_st = mean.dtype();
    auto invstd_st = invstd.dtype();
    TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types");
    const bool mixed_type = is_mixed_type(input, weight);
    using accscalar_t = at::acc_type<scalar_t, true>;

    if (cuda::detail::canUse32BitIndexMath(grad_output)) {
      if (mixed_type) {
        return batch_norm_backward_reduce_cuda_template<scalar_t, accscalar_t, int32_t>(grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
      } else {
        return batch_norm_backward_reduce_cuda_template<scalar_t, scalar_t, int32_t>(grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
      }
    } else {
      if (mixed_type) {
        return batch_norm_backward_reduce_cuda_template<scalar_t, accscalar_t, int64_t>(grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
      } else {
        return batch_norm_backward_reduce_cuda_template<scalar_t, scalar_t, int64_t>(grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
      }
    }
  });
}

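// Elementwise half of the (distributed) backward, consuming the precomputed
// sum_dy / sum_dy_xmu reductions together with the per-replica element counts.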
Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd, const std::optional<Tensor>& weight_opt, const Tensor& sum_dy, const Tensor& sum_dy_xmu, const Tensor& count) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  if (at::cuda::detail::canUse32BitIndexMath(self) &&
      batch_norm_use_channels_last_kernels(self) &&
      batch_norm_use_channels_last_kernels(input)) {
    return batch_norm_backward_elemt_channels_last_cuda_template(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
  }

  return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward_elemt", [&] {
    auto mean_st = mean.dtype();
    auto invstd_st = invstd.dtype();
    TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types");
    bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
    bool is_bfloat16_float = std::is_same<scalar_t, at::BFloat16>::value && mean_st == at::kFloat;
    using accscalar_t = at::acc_type<scalar_t, true>;
    if (cuda::detail::canUse32BitIndexMath(self)) {
      if (is_half_float || is_bfloat16_float) {
        return batch_norm_backward_elemt_cuda_template<scalar_t, accscalar_t, int32_t>(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
      } else {
        return batch_norm_backward_elemt_cuda_template<scalar_t, scalar_t, int32_t>(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
      }
    } else {
      if (is_half_float || is_bfloat16_float) {
        return batch_norm_backward_elemt_cuda_template<scalar_t, accscalar_t, int64_t>(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
      } else {
        return batch_norm_backward_elemt_cuda_template<scalar_t, scalar_t, int64_t>(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
      }
    }
  });
}

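// Computes the batch mean/var of self and, when running statistics are
// provided, folds them in with the given momentum; returns the batch
// statistics.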
std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
    const Tensor& self, const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt, double momentum) {
  c10::MaybeOwned<Tensor> running_mean = at::borrow_from_optional_tensor(running_mean_opt);
  c10::MaybeOwned<Tensor> running_var = at::borrow_from_optional_tensor(running_var_opt);

  const int64_t n_input = self.size(1);

  TORCH_CHECK(self.numel() != 0, "input tensor must have at least one element, but got input_sizes = ", self.sizes());
  auto options = self.options().dtype(
      at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
  auto save_mean = at::empty({n_input}, options);
  auto save_var = at::empty({n_input}, options);

  batch_norm_mean_var(self, save_mean, save_var);
  TORCH_CHECK(running_mean->defined() == running_var->defined());
  if (running_mean->defined()) {
    const int64_t N = self.numel() / save_mean.numel();
    batch_norm_update_stats(save_mean, save_var, *running_mean, *running_var, momentum, N);
  }
  return std::tuple<Tensor, Tensor>(save_mean, save_var);
}

} // namespace at::native