#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>

#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/batch_norm.h>
#include <ATen/native/Normalization.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_impl_index.h>
#include <ATen/ops/_batch_norm_impl_index_backward_native.h>
#include <ATen/ops/_batch_norm_impl_index_native.h>
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
#include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
#include <ATen/ops/_batch_norm_with_update.h>
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/_batch_norm_no_update.h>
#include <ATen/ops/_batch_norm_no_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/alias.h>
#include <ATen/ops/batch_norm.h>
#include <ATen/ops/batch_norm_native.h>
#include <ATen/ops/batch_norm_update_stats_native.h>
#include <ATen/ops/cudnn_batch_norm.h>
#include <ATen/ops/cudnn_batch_norm_backward.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/instance_norm_native.h>
#include <ATen/ops/linalg_vector_norm.h>
#include <ATen/ops/mean.h>
#include <ATen/ops/miopen_batch_norm.h>
#include <ATen/ops/miopen_batch_norm_backward.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/native_batch_norm.h>
#include <ATen/ops/native_batch_norm_backward.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#include <ATen/ops/_native_batch_norm_legit.h>
#include <ATen/ops/renorm_native.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/sqrt.h>
#endif

#include <c10/core/SymIntArrayRef.h>
#include <utility>
#include <vector>

static const int MIOPEN_DIM_MAX = 5;

namespace at::meta {

TORCH_META_FUNC(renorm)(const Tensor& self, const Scalar& p, int64_t dim, const Scalar& maxnorm) {
  TORCH_CHECK(!p.isComplex(), "renorm: p must be real-valued");
  TORCH_CHECK(p.toDouble() > 0.0, "renorm: non-positive-norm not supported");
  TORCH_CHECK(!maxnorm.isComplex(), "renorm: maxnorm must be real-valued");
  TORCH_CHECK(maxnorm.toDouble() >= 0.0,
              "renorm: expected maxnorm to be >= 0 but got ", maxnorm.toDouble());
  const auto ndim = self.dim();
  TORCH_CHECK(ndim > 1, "renorm: input needs at least 2 dimensions, got ", ndim, " dimensions");
  set_output_raw_strided(0, self.sizes(), {}, self.options());
}

} // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(batch_norm_cpu_stub);
DEFINE_DISPATCH(batch_norm_cpu_collect_stats_stub);
DEFINE_DISPATCH(batch_norm_cpu_backward_stub);
DEFINE_DISPATCH(renorm_scale_factor_stub);

namespace {
  void check_dims_match_num_input_features(const char* arg_name, const SymInt& expected, const SymInt& actual) {
    TORCH_CHECK(actual == expected,
                arg_name, " should contain ", expected, " elements not ", actual);
  }

  static inline Tensor repeat_if_defined(const Tensor& t, const SymInt& repeat) {
    if (t.defined()) {
      return t.repeat_symint(repeat);
    }
    return t;
  }
} // namespace

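// Transforms applied to the per-channel variance collected during training:
// InvStd maps it to 1 / sqrt(var + eps), returning 0 only when both var and
// eps are zero; Var passes the variance through unchanged. They parameterize
// batch_norm_cpu_update_stats_template below.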
template<typename T>
struct InvStd {
  T operator()(T var, double epsilon) const {
    T invstd = 0;
    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
      invstd = static_cast<T>(1) / std::sqrt(var + epsilon);
    }
    return invstd;
  }
};

template<typename T>
struct Var {
  T operator()(T var, double epsilon) const {
    return var;
  }
};

static inline bool is_contiguous(const Tensor& t) {
  return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast) || t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
}

// For some ambiguous cases, it is possible a channels last contiguous Tensor has
// `suggest_memory_format` of Contiguous.
// See https://github.com/pytorch/pytorch/issues/63224 for details.
static inline MemoryFormat suggest_memory_format_contig(const Tensor& t) {
  return t.is_contiguous() ?
      at::MemoryFormat::Contiguous : (t.is_contiguous(at::MemoryFormat::ChannelsLast3d) ?
      at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast);
}

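// Applies the per-channel affine normalization y = (x - mean) * invstd * weight + bias.
// Contiguous inputs (including channels-last layouts) take the vectorized
// batch_norm_cpu_stub path; everything else goes through a broadcasting
// TensorIterator. In training mode the batch statistics (save_mean/save_invstd)
// are used, otherwise the running statistics.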
template<typename scalar_t, typename param_t>
std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
    const Tensor& input, const Tensor& weight, const Tensor& bias,
    const Tensor& save_mean /* optional */, const Tensor& save_invstd /* optional */,
    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
    bool train, double eps, Tensor& output) {

  bool all_contiguous = is_contiguous(input)
    && is_contiguous(output)
    && (!weight.defined() || weight.is_contiguous())
    && (!bias.defined() || bias.is_contiguous())
    && running_mean.is_contiguous()
    && running_var.is_contiguous();

  // inference contiguous path
  if (all_contiguous) {
    if (input.numel() != 0) {
      batch_norm_cpu_stub(kCPU, output, input, weight, bias,
          save_mean, save_invstd, running_mean, running_var, train, eps);
    }
    return std::make_tuple(output, save_mean, save_invstd);
  }

  const int64_t ndim = input.dim();
  // Helper to convert 1d tensors to an nd tensor that broadcasts with input
  // All elements go into the channel dimension
  DimVector sizes(ndim, 1), strides(ndim, 0);
  auto as_nd = [&](const Tensor& t) {
    TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1);
    sizes[1] = t.sizes()[0];
    strides[1] = t.strides()[0];
    return t.as_strided(sizes, strides);
  };

  auto mean = as_nd(train ? save_mean : running_mean);
  auto invstd = as_nd([&]{
    if (train) {
      return save_invstd;
    } else {
      return 1 / at::sqrt(running_var + eps);
    }
  }());
  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();
  auto w = weight.defined() ? as_nd(weight) :
      at::detail::scalar_tensor_static(1, dtype, kCPU);
  auto b = bias.defined() ? as_nd(bias) :
      at::detail::scalar_tensor_static(0, dtype, kCPU);

  auto iter = TensorIteratorConfig()
    .add_output(output)
    .add_input(input)
    .add_input(mean)
    .add_input(invstd)
    .add_input(w)
    .add_input(b)
    .check_all_same_dtype(false)
    .promote_inputs_to_common_dtype(false)
    .build();
  cpu_kernel(iter, [=](scalar_t input, param_t mean, param_t invstd, param_t weight, param_t bias) -> scalar_t {
    return ((input - mean) * invstd) * weight + bias;
  });
  return std::make_tuple(output, save_mean, save_invstd);
}

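// Computes the per-channel mean and the VarTransform of the biased variance
// (divide by n) into save_mean / save_var_transform and, when defined, folds
// the batch statistics into running_mean / running_var with the given
// momentum, using the unbiased variance (divide by n - 1) for running_var.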
template<typename scalar_t, typename param_t, template<typename T> class VarTransform>
std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
    const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
    double momentum, double eps, Tensor& save_mean, Tensor& save_var_transform) {

  using accscalar_t = at::acc_type<scalar_t, false>;

  int64_t n_input = input.size(1);
  TORCH_CHECK(input.numel() != 0, "input tensor must have at least one element, but got input_sizes = ", input.sizes());
  int64_t n = input.numel() / n_input;

  bool all_contiguous = is_contiguous(input);
  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();

  auto save_mean_a = save_mean.accessor<param_t, 1>();
  auto save_var_transform_a = save_var_transform.accessor<param_t, 1>();

  auto running_mean_a = conditional_accessor_1d<param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<param_t>(running_var);

  if (all_contiguous) {
    auto _mean = at::empty({n_input}, input.options().dtype(dtype));
    auto _var_sum = at::empty({n_input}, input.options().dtype(dtype));
    auto _mean_a = _mean.accessor<param_t, 1>();
    auto _var_sum_a = _var_sum.accessor<param_t, 1>();
    auto momentum_ = static_cast<param_t>(momentum);

    batch_norm_cpu_collect_stats_stub(kCPU, _mean, _var_sum, input);

    parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
      for (const auto f : c10::irange(b_begin, b_end)) {
        save_mean_a[f] = _mean_a[f];
        save_var_transform_a[f] = VarTransform<accscalar_t>{}(_var_sum_a[f] / n, eps);

        if (running_mean.defined()) {
          running_mean_a[f] = momentum_ * _mean_a[f] + (1 - momentum_) * running_mean_a[f];
        }
        if (running_var.defined()) {
          accscalar_t unbiased_var = _var_sum_a[f] / (n - 1);
          running_var_a[f] = momentum_ * unbiased_var + (1 - momentum_) * running_var_a[f];
        }
      }
    });

    return std::make_tuple(save_mean, save_var_transform);
  }

  // non-contiguous path
  auto channel_stride = input.strides()[1];
  auto in_data = input.data_ptr<scalar_t>();
  auto reduce_iter = TensorIteratorConfig()
      .add_input(input)
      .resize_outputs(false)
      .declare_static_shape(input.sizes(), /*squash_dims=*/1)
      .check_all_same_dtype(false)
      .promote_inputs_to_common_dtype(false)
      .build();

  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
    TensorIterator iter(reduce_iter);
    for (const auto f : c10::irange(b_begin, b_end)) {
      // compute variance per input
      iter.unsafe_replace_operand(0, in_data + channel_stride * f);
      accscalar_t var_sum = 0;
      auto mean = static_cast<accscalar_t>(save_mean_a[f]);
      cpu_serial_kernel(iter, [&](const scalar_t i) -> void {
        var_sum += (i - mean) * (i - mean);
      });
      save_var_transform_a[f] = VarTransform<accscalar_t>{}(var_sum / n, eps);

      // update running averages
      if (running_mean.defined()) {
        running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
      }
      if (running_var.defined()) {
        accscalar_t unbiased_var = var_sum / (n - 1);
        running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
      }
    }
  });
  return std::make_tuple(save_mean, save_var_transform);
}

template<typename scalar_t, typename param_t, template<typename T> class VarTransform>
std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
    const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
    double momentum, double eps) {
  int64_t n_input = input.size(1);
  const int64_t ndim = input.dim();
  DimVector reduce_dims(ndim - 1);
  reduce_dims[0] = 0;
  for (const auto i : c10::irange(2, ndim)) {
    reduce_dims[i - 1] = i;
  }

  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();
  Tensor save_mean = is_contiguous(input) ? at::empty({n_input}, input.options().dtype(dtype)) : at::mean(input, /*dim=*/reduce_dims, /*keepdim=*/false, dtype);
  Tensor save_var_transform = at::empty({n_input}, input.options().dtype(dtype));
  return batch_norm_cpu_update_stats_template<scalar_t, param_t, VarTransform>(input, running_mean, running_var, momentum, eps, save_mean, save_var_transform);
}

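// CPU batch norm backward. Contiguous inputs are dispatched to
// batch_norm_cpu_backward_stub; the generic path reduces per channel with
// serial TensorIterator kernels, producing grad_input, grad_weight
// (= dot(X - mean, dL/dY) * invstd) and grad_bias (= sum(dL/dY)) as requested
// by grad_input_mask. The per-element formulas are spelled out inline below.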
template<typename scalar_t, typename param_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
    const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
    const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
    bool train, double eps, std::array<bool,3> grad_input_mask) {

  using accscalar_t = at::acc_type<scalar_t, false>;

  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();

  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;
  if (grad_input_mask[0]) {
    grad_input = at::empty_like(input, input.suggest_memory_format());
  }
  if (grad_input_mask[1]) {
    grad_weight = at::empty({input.size(1)}, input.options().dtype(dtype));
  }
  if (grad_input_mask[2]) {
    grad_bias = at::empty({input.size(1)}, input.options().dtype(dtype));
  }

  // since we are directly manipulating pointers in contiguous path,
  // need to make sure input and grad_out have the same memory format.
  bool all_contiguous = is_contiguous(input)
      && is_contiguous(grad_out_)
      && input.suggest_memory_format() == grad_out_.suggest_memory_format();

  if (all_contiguous) {
    if (grad_input_mask[0]) {
      grad_input = at::empty_like(input, suggest_memory_format_contig(input));
    }
    batch_norm_cpu_backward_stub(kCPU, grad_input, grad_weight, grad_bias,
        grad_out_, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
    return std::make_tuple(grad_input, grad_weight, grad_bias);
  }

  auto weight_a = conditional_accessor_1d<const param_t>(weight);
  auto grad_weight_a = conditional_accessor_1d<param_t>(grad_weight);
  auto grad_bias_a = conditional_accessor_1d<param_t>(grad_bias);

  int64_t n_input = input.size(1);
  int64_t n = input.numel() / n_input;

  auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
  auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);

  auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<const param_t>(running_var);

  const int64_t ndim = input.dim();

  // Reduce all dimensions except dim=1
  DimVector reduce_dims(ndim - 1);
  reduce_dims[0] = 0;
  for (const auto i : c10::irange(2, ndim)) {
    reduce_dims[i - 1] = i;
  }

  auto sum = at::sum(grad_out_, /*dim=*/reduce_dims);
  auto sum_a = sum.accessor<scalar_t, 1>();

  auto reduce_iter = TensorIteratorConfig()
      .add_const_input(input)
      .add_const_input(grad_out_)
      .resize_outputs(false)
      .declare_static_shape(input.sizes(), /*squash_dims=*/1)
      .build();

  TensorIterator unary_iter;
  TensorIterator binary_iter;
  if (grad_input_mask[0]) {
    unary_iter.build(
        TensorIteratorConfig()
        .add_output(grad_input)
        .add_const_input(train ? input : grad_out_)
        .resize_outputs(false)
        .declare_static_shape(input.sizes(), /*squash_dims=*/1));

    if (train) {
      binary_iter.build(
          TensorIteratorConfig()
          .add_output(grad_input)
          .add_input(grad_input)
          .add_const_input(grad_out_)
          .resize_outputs(false)
          .declare_static_shape(input.sizes(), /*squash_dims=*/1));
    }
  }

  auto in_channel_stride = input.strides()[1];
  auto in_data = input.const_data_ptr<scalar_t>();
  auto grad_in_channel_stride = grad_input_mask[0] ? grad_input.strides()[1] : 0;
  auto grad_in_data = grad_input_mask[0] ? grad_input.mutable_data_ptr<scalar_t>() : nullptr;
  auto grad_out_channel_stride = grad_out_.strides()[1];
  auto grad_out_data = grad_out_.const_data_ptr<scalar_t>();

  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
    TensorIterator reduce_iter_local(reduce_iter);
    TensorIterator unary_iter_local(unary_iter);
    TensorIterator binary_iter_local(binary_iter);

    for (const auto f : c10::irange(b_begin, b_end)) {
      param_t w = weight.defined() ? weight_a[f] : param_t(1);

      param_t mean{}, invstd{};
      if (train) {
        mean = save_mean_a[f];
        invstd = save_invstd_a[f];
      } else {
        mean = running_mean_a[f];
        invstd = 1 / std::sqrt(running_var_a[f] + eps);
      }

      // dot product of the Q(X) and gradOutput
      accscalar_t dotp = 0;
      reduce_iter_local.unsafe_replace_operand(
          0, const_cast<scalar_t*>(in_data + f * in_channel_stride));
      reduce_iter_local.unsafe_replace_operand(
          1, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));

      cpu_serial_kernel(reduce_iter_local, [&](const scalar_t i, const scalar_t go) -> void {
        dotp += (i - mean) * go;
      });

      if (grad_input_mask[0]) {
        if (train) {
          // when in training mode
          // Q(X) = X - E[x] ; i.e. input centered to zero mean
          // Y = Q(X) / sigma ; i.e. BN output before weight and bias
          // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / sigma * w

          // projection of gradOutput on to output scaled by std
          scalar_t k = (scalar_t) dotp * invstd * invstd / n;
          {
            unary_iter_local.unsafe_replace_operand(
                0, grad_in_data + f * grad_in_channel_stride);
            unary_iter_local.unsafe_replace_operand(
                1, const_cast<scalar_t*>(in_data + f * in_channel_stride));
            cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
              return (i - mean) * k;
            });
          }

          scalar_t grad_mean = sum_a[f] / n;
          {
            auto gI_data = grad_in_data + f * grad_in_channel_stride;
            binary_iter_local.unsafe_replace_operand(0, gI_data);
            binary_iter_local.unsafe_replace_operand(1, gI_data);
            binary_iter_local.unsafe_replace_operand(
                2, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));
            cpu_serial_kernel(binary_iter_local, [&](scalar_t gi, scalar_t go) -> scalar_t {
              return (go - grad_mean - gi) * invstd * w;
            });
          }
        } else {
          // when in evaluation mode
          // Q(X) = X - running_mean ; i.e. input centered to zero mean
          // Y = Q(X) / running_std ; i.e. BN output before weight and bias
          // dL/dX = dL/dY * w / running_std
          {
            unary_iter_local.unsafe_replace_operand(
                0, grad_in_data + f * grad_in_channel_stride);
            unary_iter_local.unsafe_replace_operand(
                1, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));
            cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
              return i * invstd * w;
            });
          }
        }
      }
      if (grad_input_mask[1]) {
        grad_weight_a[f] = dotp * invstd;
      }

      if (grad_input_mask[2]) {
        grad_bias_a[f] = sum_a[f];
      }
    }
  });
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

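// Chooses between the cuDNN, MIOpen, and native kernels based on device,
// dtype, dimensionality, batch-size limits, memory format, and whether the
// build/runtime actually provides the vendor library.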
BatchNormBackend _select_batch_norm_backend(
    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
    const Tensor& running_var, bool training, double eps) {

  auto& ctx = at::globalContext();
  bool cudnn_enabled = ctx.userEnabledCuDNN();

  if (
      input.is_cuda()
      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
      && (input.scalar_type() != at::kHalf
        || weight.scalar_type() == at::kFloat)
      && weight.defined() && bias.defined()
      && ((running_mean.defined() && running_var.defined())
        || (!running_mean.defined() && !running_var.defined() && training))
      && (input.dim() >= 3)
      && ((input.sym_size(0) <= 880801 && training) // spatial, training
        || (input.sym_size(0) <= 65535 && !training)) // spatial, eval
      && detail::getCUDAHooks().compiledWithCuDNN()
      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
  ) {
    return BatchNormBackend::Cudnn;
  }

  if (
      input.is_cuda()
      && input.dim() <= MIOPEN_DIM_MAX
      && input.scalar_type() != at::kDouble
      && input.scalar_type() != at::kBFloat16
      && (weight.scalar_type() != at::kHalf)
      && weight.defined() && bias.defined()
      && ((running_mean.defined() && running_var.defined())
        || (!running_mean.defined() && !running_var.defined() && training))
      && (input.dim() >= 3)
      && detail::getCUDAHooks().compiledWithMIOpen()
      && cudnn_enabled
      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
  ) {
    return BatchNormBackend::Miopen;
  }

  return BatchNormBackend::Native;
}


// _batch_norm_impl_index(_backward) are used in the JIT to be able to keep the run-time selection
// of backends, while enabling it to keep the information about the used backend, so that it can
// use its corresponding backward implementation.
// XXX: The indices of backends need to be kept synchronized between this function and its _backward.
// TODO: remove cudnn_enabled arg
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
    const Tensor& input, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& bias_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */,
    bool training, double momentum, double eps, bool cudnn_enabled) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  auto num_features = input.sym_sizes()[1];

  if (input.sym_numel() == 0) {
    Tensor reserve = at::empty({0}, input.options().dtype(kByte));
    auto options = input.options().dtype(
        at::toAccumulateType(input.scalar_type(), input.device().type()));
    auto save_mean = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
    auto save_invstd = at::empty_symint(c10::SymIntArrayRef({std::move(num_features)}), options);

    // don't return view of input, don't return empty tensor because it will break gradient chain
    auto out = input.clone();
    if (weight.defined()) out = out * weight[0];
    if (bias.defined()) out = out + bias[0];
    return std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t>(
        out, save_mean, save_invstd, reserve, 0);
  }

  if (running_mean.defined()) {
    check_dims_match_num_input_features("running_mean", num_features, running_mean.sym_numel());
  } else if (!training) {
    AT_ERROR("running_mean must be defined in evaluation mode");
  }
  if (running_var.defined()) {
    check_dims_match_num_input_features("running_var", num_features, running_var.sym_numel());
  } else if (!training) {
    AT_ERROR("running_var must be defined in evaluation mode");
  }
  if (weight.defined()) {
    check_dims_match_num_input_features("weight", num_features, weight.sym_numel());
  }
  if (bias.defined()) {
    check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
  }

  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);

  if (backend == BatchNormBackend::Cudnn) {
    auto input_c = input.contiguous(input.suggest_memory_format());
    auto weight_c = weight.contiguous();
    auto bias_c = bias.contiguous();
    auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
    auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;

    auto [output, save_mean, save_var, reserve] =
        at::cudnn_batch_norm(input_c, weight_c, bias_c, rmean_c, rvar_c,
                             training, momentum, eps);

    return std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t>(
        output, save_mean, save_var, reserve, 1);
  }

  Tensor reserve = at::empty({0}, input.options().dtype(kByte));

  if (backend == BatchNormBackend::Miopen) {
    return std::tuple_cat(
        at::miopen_batch_norm(
            input.contiguous(), weight.contiguous(), bias.contiguous(),
            running_mean.defined() ? running_mean.contiguous() : running_mean,
            running_var.defined() ? running_var.contiguous() : running_var,
            training, momentum, eps),
        std::tuple<Tensor>(reserve),
        std::make_tuple(2));
  }

  return std::tuple_cat(
      at::native_batch_norm(
          input, weight, bias, running_mean, running_var, training, momentum, eps),
      std::tuple<Tensor>(reserve),
      std::make_tuple(0));
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
    int64_t impl_index,
    const Tensor& input, const Tensor& grad_output, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */, const std::optional<Tensor>& save_mean_opt /* optional */, const std::optional<Tensor>& save_var_transform_opt /* optional */,
    bool train, double epsilon, std::array<bool, 3> output_mask, const Tensor& reservedSpace) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
  const Tensor& save_var_transform = c10::value_or_else(save_var_transform_opt, [] {return Tensor();});

  if (input.numel() == 0) {
    std::vector<int64_t> dims(input.dim() - 1);
    dims[0] = 0;
    std::iota(dims.begin() + 1, dims.end(), 2);

    // don't return empty tensor because it will break gradient chain
    Tensor grad_input;
    Tensor grad_weight;
    Tensor grad_bias;
    if (output_mask[2]) {
      grad_bias = grad_output.sum(dims);
    }
    if (output_mask[1]) {
      grad_weight = (grad_output * input).sum(dims);
    }
    if (output_mask[0] && weight.defined()) {
      grad_input = grad_output * weight[0];
    }
    return std::make_tuple(grad_input, grad_weight, grad_bias);
  }

  // backward in inference mode is not supported in cudnn, fallback to native
  if (impl_index == 0 || (!train)) {
    return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
  } else if (impl_index == 1) {
    // TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
    // format conversion is done inside cudnn_batch_norm_backward instead
    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace);
  } else if (impl_index == 2) {
    return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
  }
  TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
}

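// Composite `batch_norm` entry point: forwards to _batch_norm_impl_index and
// returns only the normalized output. The commented-out block below is the
// planned migration to the _batch_norm_with_update / _batch_norm_no_update ops.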
// TODO: remove cudnn_enabled arg
Tensor batch_norm(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    bool training, double momentum, double eps, bool cudnn_enabled) {
  const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();});
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                training, momentum, eps, cudnn_enabled));
  // TODO: switch to the new stack after the 2 week FC window
  // if (training) {
  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
  //     auto input_c = input;
  //     if (backend == BatchNormBackend::Cudnn) {
  //       input_c = input.contiguous(input.suggest_memory_format());
  //     } else {
  //       input_c = input.contiguous();
  //     }
  //     auto weight_c = weight.contiguous();
  //     auto bias_c = bias.contiguous();
  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
  //                                                    const_cast<Tensor&>(rvar_c), momentum, eps));
  //   } else {
  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
  //                                                    const_cast<Tensor&>(running_var), momentum, eps));
  //   }
  // } else {
  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
  //                                                momentum, eps));
  // }
}

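// Instance norm is implemented on top of batch_norm by folding the batch
// dimension into the channel dimension (shape [1, B * C, *]), repeating the
// affine parameters and running stats B times, and averaging the updated
// running stats back over the batch afterwards.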
Tensor instance_norm(
    const Tensor& input, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& bias_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */,
    bool use_input_stats, double momentum, double eps, bool cudnn_enabled) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  TORCH_CHECK(use_input_stats || (running_mean.defined() && running_var.defined()),
              "Expected running_mean and running_var to be defined when use_input_stats is false");
  std::vector<SymInt> shape = input.sym_sizes().vec();
  SymInt b = input.sym_size(0);
  SymInt c = input.sym_size(1);
  shape[1] = b * c;
  shape[0] = SymInt(1);

  Tensor weight_ = repeat_if_defined(weight, b);
  Tensor bias_ = repeat_if_defined(bias, b);
  Tensor running_mean_ = repeat_if_defined(running_mean, b);
  Tensor running_var_ = repeat_if_defined(running_var, b);

  auto input_reshaped = input.contiguous().view_symint(shape);
  auto out = at::batch_norm(input_reshaped, weight_, bias_, running_mean_, running_var_,
                            use_input_stats, momentum, eps, cudnn_enabled);

  // we alias running_mean and running_var because they are const but we want to modify their data
  if (running_mean.defined()) {
    at::alias(running_mean).copy_(running_mean_.view_symint({ b, c }).mean(0, false));
  }
  if (running_var.defined()) {
    at::alias(running_var).copy_(running_var_.view_symint({ std::move(b), std::move(c) }).mean(0, false));
  }

  return out.view_symint(input.sym_sizes());
}

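// Standalone statistics update used by batch_norm_update_stats: returns the
// batch mean and biased variance per channel and updates the running stats in
// place (eps is irrelevant for the Var transform, hence the 0 passed below).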
std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
    const Tensor& self, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, double momentum) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  const bool mixed_type = is_mixed_type(self, running_mean, running_var);
  return AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm_update_stats_cpu", [&] {
    using opmath_t = at::opmath_type<scalar_t>;
    if (mixed_type) {
      check_mixed_data_type(self, running_mean, running_var);
      return batch_norm_cpu_update_stats_template<scalar_t, opmath_t, Var>(self, running_mean, running_var, momentum, 0);
    } else {
      return batch_norm_cpu_update_stats_template<scalar_t, scalar_t, Var>(self, running_mean, running_var, momentum, 0);
    }
  });
}

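// Out-variant of the CPU forward. Resizes `out` (and, in training mode,
// save_mean / save_var) and dispatches on dtype; mixed low-precision input
// with float parameters goes through the opmath_t specializations.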
std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cpu_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  checkBackend("batch_norm_cpu_out", {self, weight, bias, running_mean, running_var}, Backend::CPU);
  // Resize out
  at::native::resize_output(out, self.sizes());

  const bool mixed_type = is_mixed_type(self, weight, bias, running_mean, running_var);
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm", [&] {
    using opmath_t = at::opmath_type<scalar_t>;
    if (mixed_type) {
      check_mixed_data_type(self, weight, bias, running_mean, running_var);
      if (!train) {
        return batch_norm_cpu_transform_input_template<scalar_t, opmath_t>(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps, out);
      } else {
        // Resize save_mean and save_var
        at::native::resize_output(save_mean, {self.size(1)});
        at::native::resize_output(save_var, {self.size(1)});
        auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, opmath_t, InvStd>(self, running_mean, running_var, momentum, eps, save_mean, save_var);
        return batch_norm_cpu_transform_input_template<scalar_t, opmath_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps, out);
      }
    } else {
      if (!train) {
        return batch_norm_cpu_transform_input_template<scalar_t, scalar_t>(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps, out);
      } else {
        // Resize save_mean and save_var
        at::native::resize_output(save_mean, {self.size(1)});
        at::native::resize_output(save_var, {self.size(1)});
        auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, scalar_t, InvStd>(self, running_mean, running_var, momentum, eps, save_mean, save_var);
        return batch_norm_cpu_transform_input_template<scalar_t, scalar_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps, out);
      }
    }
  });

  return std::tuple<Tensor&, Tensor&, Tensor&>(out, save_mean, save_var);
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    bool train, double momentum, double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU);

  // Prepare output tensor
  const bool all_contiguous = is_contiguous(self)
    && (!weight.defined() || weight.is_contiguous())
    && (!bias.defined() || bias.is_contiguous())
    && running_mean.is_contiguous()
    && running_var.is_contiguous();
  Tensor output = at::empty_like(self, all_contiguous ? suggest_memory_format_contig(self) : self.suggest_memory_format());

  // Prepare save_mean and save_var
  Tensor save_var;
  Tensor save_mean;
  const bool mixed_type = is_mixed_type(self, weight, bias, running_mean, running_var);
  const int64_t ndim = self.dim();
  DimVector reduce_dims(ndim - 1);
  reduce_dims[0] = 0;
  for (const auto i : c10::irange(2, ndim)) {
    reduce_dims[i - 1] = i;
  }
  if (mixed_type) {
    if (!train) {
      save_mean = at::empty({0}, self.options().dtype(kFloat));
      save_var = at::empty({0}, self.options().dtype(kFloat));
    } else {
      save_mean = is_contiguous(self) ? at::empty({self.size(1)}, self.options().dtype(kFloat)) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false, kFloat);
      save_var = at::empty({self.size(1)}, self.options().dtype(kFloat));
    }
  } else {
    if (!train) {
      save_mean = at::empty({0}, self.options());
      save_var = at::empty({0}, self.options());
    } else {
      save_mean = is_contiguous(self) ? at::empty({self.size(1)}, self.options()) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false);
      save_var = at::empty({self.size(1)}, self.options());
    }
  }
  return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
}

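// Forward in training mode that also updates running_mean / running_var in
// place. The returned reserve tensor is always empty on CPU; it mirrors the
// reserve space produced by the cuDNN/MIOpen variants of this op.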
std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
  auto [output, save_mean, save_var] =
    batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
  std::tie(out, save_mean, save_var) =
    batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}

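// Inference-flavored counterpart: runs the forward with update=false so the
// running statistics are read but never modified.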
std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    double momentum, double eps) {
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  auto [output, save_mean, save_var] =
    batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps) {
  return batch_norm_cpu(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_stats_cpu(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    bool train, double momentum, double eps) {
  return batch_norm_cpu(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_training(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    const Tensor& running_mean, const Tensor& running_var, double momentum, double eps) {
  return at::_native_batch_norm_legit(self, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*train=*/false, momentum, eps);
}


std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_cpu_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) {
  return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps, out, save_mean, save_var);
}


std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) {
  return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
}

std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt,
    bool train, double eps, std::array<bool,3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
  const Tensor& save_invstd = c10::value_or_else(save_invstd_opt, [] {return Tensor();});

  const bool mixed_type = is_mixed_type(self, weight, running_mean, running_var, save_mean, save_invstd);
  return AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm_backward_cpu", [&] {
    using opmath_t = at::opmath_type<scalar_t>;
    if (mixed_type) {
      check_mixed_data_type(self, weight, running_mean, running_var, save_mean, save_invstd);
      return batch_norm_backward_cpu_template<scalar_t, opmath_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask);
    } else {
      return batch_norm_backward_cpu_template<scalar_t, scalar_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask);
    }
  });
}

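// renorm: for each slice of `self` along `dim`, computes the p-norm over the
// remaining dimensions and rescales slices whose norm exceeds maxnorm down to
// (approximately) maxnorm; slices within the limit get a scale factor of 1
// from renorm_scale_factor_stub.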
TORCH_IMPL_FUNC(renorm_out)(const Tensor& self, const Scalar& p, int64_t dim,
                            const Scalar& maxnorm, const Tensor& out) {
  auto self_sizes = self.sizes();
  dim = c10::maybe_wrap_dim(dim, self_sizes.size());

  DimVector reduce_dims(self_sizes.size());
  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
  reduce_dims.erase(reduce_dims.begin() + dim);

  // For cuda half, calculate norm in float precision then cast
  // normalization factor to half
  auto dtype = self.scalar_type();
  auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
  Tensor norm;
  if (acc_type != dtype) {
    norm = at::linalg_vector_norm(self, p.toDouble(), reduce_dims,
                                  /*keepdim=*/true, /*dtype=*/acc_type);
  } else {
    norm = at::linalg_vector_norm(self, p.toDouble(), reduce_dims,
                                  /*keepdim=*/true);
  }

  auto factor = (acc_type == c10::toRealValueType(dtype)) ?
      norm : at::empty(norm.sizes(), self.options());
  auto iter = TensorIteratorConfig()
    .add_output(factor)
    .add_input(norm)
    .set_check_mem_overlap(false)
    .cast_common_dtype_to_outputs(true)
    .build();

  renorm_scale_factor_stub(iter.device_type(), iter, maxnorm.toDouble());
  at::mul_outf(self, factor, const_cast<Tensor&>(out));
}

} // namespace at::native