#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>

#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/batch_norm.h>
#include <ATen/native/Normalization.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_impl_index.h>
#include <ATen/ops/_batch_norm_impl_index_backward_native.h>
#include <ATen/ops/_batch_norm_impl_index_native.h>
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
#include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
#include <ATen/ops/_batch_norm_with_update.h>
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/_batch_norm_no_update.h>
#include <ATen/ops/_batch_norm_no_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/alias.h>
#include <ATen/ops/batch_norm.h>
#include <ATen/ops/batch_norm_native.h>
#include <ATen/ops/batch_norm_update_stats_native.h>
#include <ATen/ops/cudnn_batch_norm.h>
#include <ATen/ops/cudnn_batch_norm_backward.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/instance_norm_native.h>
#include <ATen/ops/linalg_vector_norm.h>
#include <ATen/ops/mean.h>
#include <ATen/ops/miopen_batch_norm.h>
#include <ATen/ops/miopen_batch_norm_backward.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/native_batch_norm.h>
#include <ATen/ops/native_batch_norm_backward.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#include <ATen/ops/_native_batch_norm_legit.h>
#include <ATen/ops/renorm_native.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/sqrt.h>
#endif

#include <c10/core/SymIntArrayRef.h>
#include <utility>
#include <vector>

static const int MIOPEN_DIM_MAX = 5;

namespace at::meta {

TORCH_META_FUNC(renorm)(const Tensor& self, const Scalar& p, int64_t dim, const Scalar& maxnorm) {
  TORCH_CHECK(!p.isComplex(), "renorm: p must be real-valued");
  TORCH_CHECK(p.toDouble() > 0.0, "renorm: non-positive-norm not supported");
  TORCH_CHECK(!maxnorm.isComplex(), "renorm: maxnorm must be real-valued");
  TORCH_CHECK(maxnorm.toDouble() >= 0.0,
              "renorm: expected maxnorm to be >= 0 but got ", maxnorm.toDouble());
  const auto ndim = self.dim();
  TORCH_CHECK(ndim > 1, "renorm: input needs at least 2 dimensions, got ", ndim, " dimensions");
  set_output_raw_strided(0, self.sizes(), {}, self.options());
}

}  // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(batch_norm_cpu_stub);
DEFINE_DISPATCH(batch_norm_cpu_collect_stats_stub);
DEFINE_DISPATCH(batch_norm_cpu_backward_stub);
DEFINE_DISPATCH(renorm_scale_factor_stub);

namespace {
  void check_dims_match_num_input_features(const char* arg_name, const SymInt& expected, const SymInt& actual) {
    TORCH_CHECK(actual == expected,
             arg_name, " should contain ", expected, " elements not ", actual);
  }

  static inline Tensor repeat_if_defined(const Tensor& t, const SymInt& repeat) {
    if (t.defined()) {
      return t.repeat_symint(repeat);
    }
    return t;
  }
}

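// Transformations applied to the per-channel variance collected for a batch:
// InvStd converts a (biased) variance into 1 / sqrt(var + eps), which is what
// the forward kernels consume; Var returns the variance unchanged and is used
// by batch_norm_update_stats.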
template<typename T>
struct InvStd {
  T operator()(T var, double epsilon) const {
    T invstd = 0;
    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
      invstd = static_cast<T>(1) / std::sqrt(var + epsilon);
    }
    return invstd;
  }
};

template<typename T>
struct Var {
  T operator()(T var, double epsilon) const {
    return var;
  }
};

static inline bool is_contiguous(const Tensor& t) {
  return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast) || t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
}

// For some ambiguous cases, it is possible a channels last contiguous Tensor has
//   `suggest_memory_format` of Contiguous.
// See https://github.com/pytorch/pytorch/issues/63224 for details.
static inline MemoryFormat suggest_memory_format_contig(const Tensor& t) {
  return t.is_contiguous() ?
    at::MemoryFormat::Contiguous : (t.is_contiguous(at::MemoryFormat::ChannelsLast3d) ?
    at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast);
}

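// Forward transform: given per-channel mean and inverse standard deviation
// (the batch statistics in training, the running statistics in eval), computes
// output = (input - mean) * invstd * weight + bias. Uses the vectorized
// batch_norm_cpu_stub when all tensors are contiguous, otherwise falls back to
// a broadcasting TensorIterator loop.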
template<typename scalar_t, typename param_t>
std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
    const Tensor& input, const Tensor& weight, const Tensor& bias,
    const Tensor& save_mean /* optional */, const Tensor& save_invstd /* optional */,
    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
    bool train, double eps, Tensor& output) {

  bool all_contiguous = is_contiguous(input)
    && is_contiguous(output)
    && (!weight.defined() || weight.is_contiguous())
    && (!bias.defined() || bias.is_contiguous())
    && running_mean.is_contiguous()
    && running_var.is_contiguous();

  // inference contiguous path
  if (all_contiguous) {
    if (input.numel() != 0) {
      batch_norm_cpu_stub(kCPU, output, input, weight, bias,
          save_mean, save_invstd, running_mean, running_var, train, eps);
    }
    return std::make_tuple(output, save_mean, save_invstd);
  }

  const int64_t ndim = input.dim();
  // Helper to convert 1d tensors to an nd tensor that broadcasts with input
  // All elements go into the channel dimension
  DimVector sizes(ndim, 1), strides(ndim, 0);
  auto as_nd = [&](const Tensor& t) {
    TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1);
    sizes[1] = t.sizes()[0];
    strides[1] = t.strides()[0];
    return t.as_strided(sizes, strides);
  };

  auto mean = as_nd(train ? save_mean : running_mean);
  auto invstd = as_nd([&]{
    if (train) {
      return save_invstd;
    } else {
      return 1 / at::sqrt(running_var + eps);
    }
  }());
  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();
  auto w = weight.defined() ? as_nd(weight) :
      at::detail::scalar_tensor_static(1, dtype, kCPU);
  auto b = bias.defined() ? as_nd(bias) :
      at::detail::scalar_tensor_static(0, dtype, kCPU);

  auto iter = TensorIteratorConfig()
    .add_output(output)
    .add_input(input)
    .add_input(mean)
    .add_input(invstd)
    .add_input(w)
    .add_input(b)
    .check_all_same_dtype(false)
    .promote_inputs_to_common_dtype(false)
    .build();
  cpu_kernel(iter, [=](scalar_t input, param_t mean, param_t invstd, param_t weight, param_t bias) -> scalar_t {
    return ((input - mean) * invstd) * weight + bias;
  });
  return std::make_tuple(output, save_mean, save_invstd);
}

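// Collects per-channel statistics over every dimension except dim 1: computes
// the batch mean and a VarTransform of the biased variance (InvStd for the
// forward pass, Var for batch_norm_update_stats), and, when running_mean /
// running_var are defined, folds the batch statistics into them using
// `momentum` (the running variance is updated with the unbiased estimate).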
template<typename scalar_t, typename param_t, template<typename T> class VarTransform>
std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
    const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
    double momentum, double eps, Tensor& save_mean, Tensor& save_var_transform) {

  using accscalar_t = at::acc_type<scalar_t, false>;

  int64_t n_input = input.size(1);
  TORCH_CHECK(input.numel() != 0, "input tensor must have at least one element, but got input_sizes = ", input.sizes());
  int64_t n = input.numel() / n_input;

  bool all_contiguous = is_contiguous(input);
  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();

  auto save_mean_a = save_mean.accessor<param_t, 1>();
  auto save_var_transform_a = save_var_transform.accessor<param_t, 1>();

  auto running_mean_a = conditional_accessor_1d<param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<param_t>(running_var);

  if (all_contiguous) {
    auto _mean = at::empty({n_input}, input.options().dtype(dtype));
    auto _var_sum = at::empty({n_input}, input.options().dtype(dtype));
    auto _mean_a = _mean.accessor<param_t, 1>();
    auto _var_sum_a = _var_sum.accessor<param_t, 1>();
    auto momentum_ = static_cast<param_t>(momentum);

    batch_norm_cpu_collect_stats_stub(kCPU, _mean, _var_sum, input);

    parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
      for (const auto f : c10::irange(b_begin, b_end)) {
        save_mean_a[f] = _mean_a[f];
        save_var_transform_a[f] = VarTransform<accscalar_t>{}(_var_sum_a[f] / n, eps);

        if (running_mean.defined()) {
          running_mean_a[f] = momentum_ * _mean_a[f] + (1 - momentum_) * running_mean_a[f];
        }
        if (running_var.defined()) {
          accscalar_t unbiased_var = _var_sum_a[f] / (n - 1);
          running_var_a[f] = momentum_ * unbiased_var + (1 - momentum_) * running_var_a[f];
        }
      }
    });

    return std::make_tuple(save_mean, save_var_transform);
  }

  // non-contiguous path
  auto channel_stride = input.strides()[1];
  auto in_data = input.data_ptr<scalar_t>();
  auto reduce_iter = TensorIteratorConfig()
      .add_input(input)
      .resize_outputs(false)
      .declare_static_shape(input.sizes(), /*squash_dims=*/1)
      .check_all_same_dtype(false)
      .promote_inputs_to_common_dtype(false)
      .build();

  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
    TensorIterator iter(reduce_iter);
    for (const auto f : c10::irange(b_begin, b_end)) {
      // compute variance per input
      iter.unsafe_replace_operand(0, in_data + channel_stride * f);
      accscalar_t var_sum = 0;
      auto mean = static_cast<accscalar_t>(save_mean_a[f]);
      cpu_serial_kernel(iter, [&](const scalar_t i) -> void {
        var_sum += (i - mean) * (i - mean);
      });
      save_var_transform_a[f] = VarTransform<accscalar_t>{}(var_sum / n, eps);

      // update running averages
      if (running_mean.defined()) {
        running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
      }
      if (running_var.defined()) {
        accscalar_t unbiased_var = var_sum / (n - 1);
        running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
      }
    }
  });
  return std::make_tuple(save_mean, save_var_transform);
}

template<typename scalar_t, typename param_t, template<typename T> class VarTransform>
std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
    const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
    double momentum, double eps) {
  int64_t n_input = input.size(1);
  const int64_t ndim = input.dim();
  DimVector reduce_dims(ndim - 1);
  reduce_dims[0] = 0;
  for (const auto i : c10::irange(2, ndim)) {
    reduce_dims[i - 1] = i;
  }

  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();
  Tensor save_mean = is_contiguous(input) ? at::empty({n_input}, input.options().dtype(dtype)) : at::mean(input, /*dim=*/reduce_dims, /*keepdim=*/false, dtype);
  Tensor save_var_transform = at::empty({n_input}, input.options().dtype(dtype));
  return batch_norm_cpu_update_stats_template<scalar_t, param_t, VarTransform>(input, running_mean, running_var, momentum, eps, save_mean, save_var_transform);
}

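// Backward pass on CPU. For each channel it accumulates sum(dL/dY) and
// dot(X - mean, dL/dY), then fills in, per grad_input_mask: grad_input
// (training: the full batch-norm gradient; eval: dL/dY * invstd * weight),
// grad_weight = dotp * invstd, and grad_bias = sum(dL/dY). The fully
// contiguous case is forwarded to the vectorized batch_norm_cpu_backward_stub.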
template<typename scalar_t, typename param_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
    const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
    const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
    bool train, double eps, std::array<bool,3> grad_input_mask) {

  using accscalar_t = at::acc_type<scalar_t, false>;

  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();

  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;
  if (grad_input_mask[0]) {
    grad_input = at::empty_like(input, input.suggest_memory_format());
  }
  if (grad_input_mask[1]) {
    grad_weight = at::empty({input.size(1)}, input.options().dtype(dtype));
  }
  if (grad_input_mask[2]) {
    grad_bias = at::empty({input.size(1)}, input.options().dtype(dtype));
  }

  // since we are directly manipulating pointers in contiguous path,
  // need to make sure input and grad_out have the same memory format.
  bool all_contiguous = is_contiguous(input)
      && is_contiguous(grad_out_)
      && input.suggest_memory_format() == grad_out_.suggest_memory_format();

  if (all_contiguous) {
    if (grad_input_mask[0]) {
      grad_input = at::empty_like(input, suggest_memory_format_contig(input));
    }
    batch_norm_cpu_backward_stub(kCPU, grad_input, grad_weight, grad_bias,
        grad_out_, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
    return std::make_tuple(grad_input, grad_weight, grad_bias);
  }

  auto weight_a = conditional_accessor_1d<const param_t>(weight);
  auto grad_weight_a = conditional_accessor_1d<param_t>(grad_weight);
  auto grad_bias_a = conditional_accessor_1d<param_t>(grad_bias);

  int64_t n_input = input.size(1);
  int64_t n = input.numel() / n_input;

  auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
  auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);

  auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<const param_t>(running_var);

  const int64_t ndim = input.dim();

  // Reduce all dimensions except dim=1
  DimVector reduce_dims(ndim - 1);
  reduce_dims[0] = 0;
  for (const auto i : c10::irange(2, ndim)) {
    reduce_dims[i - 1] = i;
  }

  auto sum = at::sum(grad_out_, /*dim=*/reduce_dims);
  auto sum_a = sum.accessor<scalar_t, 1>();

  auto reduce_iter = TensorIteratorConfig()
      .add_const_input(input)
      .add_const_input(grad_out_)
      .resize_outputs(false)
      .declare_static_shape(input.sizes(), /*squash_dims=*/1)
      .build();

  TensorIterator unary_iter;
  TensorIterator binary_iter;
  if (grad_input_mask[0]) {
    unary_iter.build(
        TensorIteratorConfig()
        .add_output(grad_input)
        .add_const_input(train ? input : grad_out_)
        .resize_outputs(false)
        .declare_static_shape(input.sizes(), /*squash_dims=*/1));

    if (train) {
      binary_iter.build(
          TensorIteratorConfig()
          .add_output(grad_input)
          .add_input(grad_input)
          .add_const_input(grad_out_)
          .resize_outputs(false)
          .declare_static_shape(input.sizes(), /*squash_dims=*/1));
    }
  }

  auto in_channel_stride = input.strides()[1];
  auto in_data = input.const_data_ptr<scalar_t>();
  auto grad_in_channel_stride = grad_input_mask[0] ? grad_input.strides()[1] : 0;
  auto grad_in_data = grad_input_mask[0] ? grad_input.mutable_data_ptr<scalar_t>() : nullptr;
  auto grad_out_channel_stride = grad_out_.strides()[1];
  auto grad_out_data = grad_out_.const_data_ptr<scalar_t>();

  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
      TensorIterator reduce_iter_local(reduce_iter);
      TensorIterator unary_iter_local(unary_iter);
      TensorIterator binary_iter_local(binary_iter);

      for (const auto f : c10::irange(b_begin, b_end)) {
        param_t w = weight.defined() ? weight_a[f] : param_t(1);

        param_t mean{}, invstd{};
        if (train) {
          mean = save_mean_a[f];
          invstd = save_invstd_a[f];
        } else {
          mean = running_mean_a[f];
          invstd = 1 / std::sqrt(running_var_a[f] + eps);
        }

        // dot product of the Q(X) and gradOutput
        accscalar_t dotp = 0;
        reduce_iter_local.unsafe_replace_operand(
            0, const_cast<scalar_t*>(in_data + f * in_channel_stride));
        reduce_iter_local.unsafe_replace_operand(
            1, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));

        cpu_serial_kernel(reduce_iter_local, [&](const scalar_t i, const scalar_t go) -> void {
          dotp += (i - mean) * go;
        });

        if (grad_input_mask[0]) {
          if (train) {
            // when in training mode
            // Q(X) = X - E[x] ; i.e. input centered to zero mean
            // Y = Q(X) / sigma    ; i.e. BN output before weight and bias
            // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / sigma * w

            // projection of gradOutput on to output scaled by std
            scalar_t k = (scalar_t) dotp * invstd * invstd / n;
            {
              unary_iter_local.unsafe_replace_operand(
                  0, grad_in_data + f * grad_in_channel_stride);
              unary_iter_local.unsafe_replace_operand(
                  1, const_cast<scalar_t*>(in_data + f * in_channel_stride));
              cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
                return (i - mean) * k;
              });
            }

            scalar_t grad_mean = sum_a[f] / n;
            {
              auto gI_data = grad_in_data + f * grad_in_channel_stride;
              binary_iter_local.unsafe_replace_operand(0, gI_data);
              binary_iter_local.unsafe_replace_operand(1, gI_data);
              binary_iter_local.unsafe_replace_operand(
                  2, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));
              cpu_serial_kernel(binary_iter_local, [&](scalar_t gi, scalar_t go) -> scalar_t {
                return (go - grad_mean - gi) * invstd * w;
              });
            }
          } else {
            // when in evaluation mode
            // Q(X) = X - running_mean  ; i.e. input centered to zero mean
            // Y = Q(X) / running_std    ; i.e. BN output before weight and bias
            // dL/dX = w / running_std
            {
              unary_iter_local.unsafe_replace_operand(
                  0, grad_in_data + f * grad_in_channel_stride);
              unary_iter_local.unsafe_replace_operand(
                  1, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));
              cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
                return i * invstd * w;
              });
            }
          }
        }
        if (grad_input_mask[1]) {
          grad_weight_a[f] = dotp * invstd;
        }

        if (grad_input_mask[2]) {
          grad_bias_a[f] = sum_a[f];
        }
      }
    });
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

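// Picks the batch-norm backend (cuDNN, MIOpen or the native kernels) based on
// device, dtype, tensor rank and size, memory format, whether the affine and
// running-stat tensors are defined, and what the build was compiled with.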
BatchNormBackend _select_batch_norm_backend(
    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
    const Tensor& running_var, bool training, double eps) {

  auto& ctx = at::globalContext();
  bool cudnn_enabled = ctx.userEnabledCuDNN();

  if (
      input.is_cuda()
      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
      && (input.scalar_type() != at::kHalf
        || weight.scalar_type() == at::kFloat)
      && weight.defined() && bias.defined()
      && ((running_mean.defined() && running_var.defined())
        || (!running_mean.defined() && !running_var.defined() && training))
      && (input.dim() >= 3)
      && ((input.sym_size(0) <= 880801 && training) // spatial, training
          || (input.sym_size(0) <= 65535 && !training)) // spatial, eval
      && detail::getCUDAHooks().compiledWithCuDNN()
      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
  ) {
    return BatchNormBackend::Cudnn;
  }

  if (
      input.is_cuda()
      && input.dim() <= MIOPEN_DIM_MAX
      && input.scalar_type() != at::kDouble
      && input.scalar_type() != at::kBFloat16
      && (weight.scalar_type() != at::kHalf)
      && weight.defined() && bias.defined()
      && ((running_mean.defined() && running_var.defined())
        || (!running_mean.defined() && !running_var.defined() && training))
      && (input.dim() >= 3)
      && detail::getCUDAHooks().compiledWithMIOpen()
      && cudnn_enabled
      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
  ) {
    return BatchNormBackend::Miopen;
  }

  return BatchNormBackend::Native;
}


// _batch_norm_impl_index(_backward) are used in the JIT to be able to keep the run-time selection
// of backends, while enabling it to keep the information about the used backend, so that it can
// use its corresponding backward implementation.
// XXX: The indices of backends need to be kept synchronized between this function and its _backward.
// TODO: remove cudnn_enabled arg
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
    const Tensor& input, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& bias_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */,
    bool training, double momentum, double eps, bool cudnn_enabled) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  auto num_features = input.sym_sizes()[1];

  if (input.sym_numel() == 0) {
    Tensor reserve = at::empty({0}, input.options().dtype(kByte));
    auto options = input.options().dtype(
        at::toAccumulateType(input.scalar_type(), input.device().type()));
    auto save_mean = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
    auto save_invstd = at::empty_symint(c10::SymIntArrayRef({std::move(num_features)}), options);

    // don't return view of input, don't return empty tensor because it will break gradient chain
    auto out = input.clone();
    if (weight.defined()) out = out * weight[0];
    if (bias.defined()) out = out + bias[0];
    return std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t>(
        out, save_mean, save_invstd, reserve, 0);
  }

  if (running_mean.defined()) {
    check_dims_match_num_input_features("running_mean", num_features, running_mean.sym_numel());
  } else if (!training) {
    AT_ERROR("running_mean must be defined in evaluation mode");
  }
  if (running_var.defined()) {
    check_dims_match_num_input_features("running_var", num_features, running_var.sym_numel());
  } else if (!training) {
    AT_ERROR("running_var must be defined in evaluation mode");
  }
  if (weight.defined()) {
    check_dims_match_num_input_features("weight", num_features, weight.sym_numel());
  }
  if (bias.defined()) {
    check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
  }

  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);

  if (backend == BatchNormBackend::Cudnn) {
    auto input_c = input.contiguous(input.suggest_memory_format());
    auto weight_c = weight.contiguous();
    auto bias_c = bias.contiguous();
    auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
    auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;

    auto [output, save_mean, save_var, reserve] =
        at::cudnn_batch_norm(input_c, weight_c, bias_c, rmean_c, rvar_c,
                             training, momentum, eps);

    return std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t>(
        output, save_mean, save_var, reserve, 1);
  }

  Tensor reserve = at::empty({0}, input.options().dtype(kByte));

  if (backend == BatchNormBackend::Miopen) {
    return std::tuple_cat(
             at::miopen_batch_norm(
               input.contiguous(), weight.contiguous(), bias.contiguous(),
               running_mean.defined() ? running_mean.contiguous() : running_mean,
               running_var.defined() ? running_var.contiguous() : running_var,
               training, momentum, eps),
             std::tuple<Tensor>(reserve),
             std::make_tuple(2));
  }

  return std::tuple_cat(
           at::native_batch_norm(
             input, weight, bias, running_mean, running_var, training, momentum, eps),
           std::tuple<Tensor>(reserve),
           std::make_tuple(0));
}

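// Dispatches to the backward that matches the impl_index recorded by
// _batch_norm_impl_index above: 0 = native, 1 = cuDNN, 2 = MIOpen. cuDNN does
// not support backward in inference mode, so that case always falls back to
// the native implementation.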
std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
    int64_t impl_index,
    const Tensor& input, const Tensor& grad_output, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */, const std::optional<Tensor>& save_mean_opt /* optional */, const std::optional<Tensor>& save_var_transform_opt /* optional */,
    bool train, double epsilon, std::array<bool, 3> output_mask, const Tensor &reservedSpace) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
  const Tensor& save_var_transform = c10::value_or_else(save_var_transform_opt, [] {return Tensor();});

  if (input.numel() == 0) {
    std::vector<int64_t> dims(input.dim() - 1);
    dims[0] = 0;
    std::iota(dims.begin() + 1, dims.end(), 2);

    // don't return empty tensor because it will break gradient chain
    Tensor grad_input;
    Tensor grad_weight;
    Tensor grad_bias;
    if (output_mask[2]) {
      grad_bias = grad_output.sum(dims);
    }
    if (output_mask[1]) {
      grad_weight = (grad_output * input).sum(dims);
    }
    if (output_mask[0] && weight.defined()) {
      grad_input = grad_output * weight[0];
    }
    return std::make_tuple(grad_input, grad_weight, grad_bias);
  }

  // backward in inference mode is not supported in cudnn, fallback to native
  if (impl_index == 0 || (!train)) {
    return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
  } else if (impl_index == 1) {
    // TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
    // format conversion is done inside cudnn_batch_norm_backward instead
    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace);
  } else if (impl_index == 2) {
    return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
  }
  TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
}

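// Composite entry point behind torch.batch_norm (and, for dense tensors,
// torch.nn.functional.batch_norm / the nn.BatchNormNd modules): it forwards to
// _batch_norm_impl_index and returns only the normalized output, discarding
// the saved statistics, reserve buffer and backend index.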
// TODO: remove cudnn_enabled arg
Tensor batch_norm(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    bool training, double momentum, double eps, bool cudnn_enabled) {
  const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();});
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                training, momentum, eps, cudnn_enabled));
  // TODO: switch to the new stack after the 2 week FC window
  // if (training) {
  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
  //     auto input_c = input;
  //     if (backend == BatchNormBackend::Cudnn) {
  //         input_c = input.contiguous(input.suggest_memory_format());
  //     } else {
  //         input_c = input.contiguous();
  //     }
  //     auto weight_c = weight.contiguous();
  //     auto bias_c = bias.contiguous();
  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
  //                                                   const_cast<Tensor&>(rvar_c), momentum, eps));
  //   } else {
  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
  //                                                   const_cast<Tensor&>(running_var), momentum, eps));
  //   }
  // } else {
  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
  //                                               momentum, eps));
  // }
}

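// Instance norm is expressed in terms of batch norm: the input is reshaped
// from (N, C, ...) to (1, N * C, ...) so that every (sample, channel) pair
// gets its own statistics, weight/bias/running stats are tiled N times to
// match, and after the call the running statistics are averaged back over the
// batch dimension into the caller's buffers.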
Tensor instance_norm(
    const Tensor& input, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& bias_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */,
    bool use_input_stats, double momentum, double eps, bool cudnn_enabled) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  TORCH_CHECK(use_input_stats || (running_mean.defined() && running_var.defined()),
              "Expected running_mean and running_var to be defined when use_input_stats is false");
  std::vector<SymInt> shape = input.sym_sizes().vec();
  SymInt b = input.sym_size(0);
  SymInt c = input.sym_size(1);
  shape[1] = b * c;
  shape[0] = SymInt(1);

  Tensor weight_ = repeat_if_defined(weight, b);
  Tensor bias_ = repeat_if_defined(bias, b);
  Tensor running_mean_ = repeat_if_defined(running_mean, b);
  Tensor running_var_ = repeat_if_defined(running_var, b);

  auto input_reshaped = input.contiguous().view_symint(shape);
  auto out = at::batch_norm(input_reshaped, weight_, bias_, running_mean_, running_var_,
                            use_input_stats, momentum, eps, cudnn_enabled);

  // we alias running_mean and running_var because they are const but we want to modify their data
  if (running_mean.defined()) {
    at::alias(running_mean).copy_(running_mean_.view_symint({ b, c }).mean(0, false));
  }
  if (running_var.defined()) {
    at::alias(running_var).copy_(running_var_.view_symint({ std::move(b), std::move(c) }).mean(0, false));
  }

  return out.view_symint(input.sym_sizes());
}

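// batch_norm_update_stats on CPU: computes the per-channel batch mean and
// (biased) variance and folds them into running_mean / running_var, without
// producing a normalized output; eps is not needed here, so 0 is passed.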
std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
        const Tensor& self, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, double momentum) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  const bool mixed_type = is_mixed_type(self, running_mean, running_var);
  return AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm_update_stats_cpu", [&] {
    using opmath_t = at::opmath_type<scalar_t>;
    if (mixed_type) {
      check_mixed_data_type(self, running_mean, running_var);
      return batch_norm_cpu_update_stats_template<scalar_t, opmath_t, Var>(self, running_mean, running_var, momentum, 0);
    } else {
      return batch_norm_cpu_update_stats_template<scalar_t, scalar_t, Var>(self, running_mean, running_var, momentum, 0);
    }
  });
}

std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cpu_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
                                                  bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  checkBackend("batch_norm_cpu_out", {self, weight, bias, running_mean, running_var}, Backend::CPU);
  // Resize out
  at::native::resize_output(out, self.sizes());

  const bool mixed_type = is_mixed_type(self, weight, bias, running_mean, running_var);
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm", [&] {
    using opmath_t = at::opmath_type<scalar_t>;
    if (mixed_type) {
      check_mixed_data_type(self, weight, bias, running_mean, running_var);
      if (!train) {
        return batch_norm_cpu_transform_input_template<scalar_t, opmath_t>(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps, out);
      } else {
        // Resize save_mean and save_var
        at::native::resize_output(save_mean, {self.size(1)});
        at::native::resize_output(save_var, {self.size(1)});
        auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, opmath_t, InvStd>(self, running_mean, running_var, momentum, eps, save_mean, save_var);
        return batch_norm_cpu_transform_input_template<scalar_t, opmath_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps, out);
      }
    } else {
      if (!train) {
        return batch_norm_cpu_transform_input_template<scalar_t, scalar_t>(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps, out);
      } else {
        // Resize save_mean and save_var
        at::native::resize_output(save_mean, {self.size(1)});
        at::native::resize_output(save_var, {self.size(1)});
        auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, scalar_t, InvStd>(self, running_mean, running_var, momentum, eps, save_mean, save_var);
        return batch_norm_cpu_transform_input_template<scalar_t, scalar_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps, out);
      }
    }
  });

  return std::tuple<Tensor&, Tensor&, Tensor&>(out, save_mean, save_var);
}

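// native_batch_norm on CPU: allocates the output in a memory format matching
// the input, sizes save_mean / save_var for the training path (they stay empty
// in eval mode), and defers the actual work to batch_norm_cpu_out.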
std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
                                                  bool train, double momentum, double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU);

  // Prepare output tensor
  const bool all_contiguous = is_contiguous(self)
    && (!weight.defined() || weight.is_contiguous())
    && (!bias.defined() || bias.is_contiguous())
    && running_mean.is_contiguous()
    && running_var.is_contiguous();
  Tensor output = at::empty_like(self, all_contiguous ? suggest_memory_format_contig(self) : self.suggest_memory_format());

  // Prepare save_mean and save_var
  Tensor save_var;
  Tensor save_mean;
  const bool mixed_type = is_mixed_type(self, weight, bias, running_mean, running_var);
  const int64_t ndim = self.dim();
  DimVector reduce_dims(ndim - 1);
  reduce_dims[0] = 0;
  for (const auto i : c10::irange(2, ndim)) {
    reduce_dims[i - 1] = i;
  }
  if (mixed_type) {
    if (!train) {
      save_mean = at::empty({0}, self.options().dtype(kFloat));
      save_var = at::empty({0}, self.options().dtype(kFloat));
    } else {
      save_mean = is_contiguous(self) ? at::empty({self.size(1)}, self.options().dtype(kFloat)) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false, kFloat);
      save_var = at::empty({self.size(1)}, self.options().dtype(kFloat));
    }
  } else {
    if (!train) {
      save_mean = at::empty({0}, self.options());
      save_var = at::empty({0}, self.options());
    } else {
      save_mean = is_contiguous(self) ? at::empty({self.size(1)}, self.options()) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false);
      save_var = at::empty({self.size(1)}, self.options());
    }
  }
  return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
}

std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
  auto [output, save_mean, save_var] =
    batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
  std::tie(out, save_mean, save_var) =
    batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}


std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    double momentum, double eps) {
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  auto [output, save_mean, save_var] =
    batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps) {
  return batch_norm_cpu(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_stats_cpu(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    bool train, double momentum, double eps) {
  return batch_norm_cpu(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
}
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_training(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    const Tensor& running_mean, const Tensor& running_var, double momentum, double eps) {
  return at::_native_batch_norm_legit(self, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*train=*/false, momentum, eps);
}


std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_cpu_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) {
  return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps, out, save_mean, save_var);
}


std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) {
  return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
}

std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt,
                                                           bool train, double eps, std::array<bool,3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
  const Tensor& save_invstd = c10::value_or_else(save_invstd_opt, [] {return Tensor();});

  const bool mixed_type = is_mixed_type(self, weight, running_mean, running_var, save_mean, save_invstd);
  return AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm_backward_cpu", [&] {
    using opmath_t = at::opmath_type<scalar_t>;
    if (mixed_type) {
      check_mixed_data_type(self, weight, running_mean, running_var, save_mean, save_invstd);
      return batch_norm_backward_cpu_template<scalar_t, opmath_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask);
    } else {
      return batch_norm_backward_cpu_template<scalar_t, scalar_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask);
    }
  });
}

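// renorm: for every slice of `self` along `dim`, computes its p-norm over the
// remaining dimensions and rescales the slice so that its norm does not exceed
// maxnorm (slices already within the limit are left unchanged). The norm is
// accumulated in float for reduced-precision inputs before the per-slice scale
// factor is cast back to the input dtype.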
TORCH_IMPL_FUNC(renorm_out)(const Tensor& self, const Scalar& p, int64_t dim,
                            const Scalar& maxnorm, const Tensor& out) {
  auto self_sizes = self.sizes();
  dim = c10::maybe_wrap_dim(dim, self_sizes.size());

  DimVector reduce_dims(self_sizes.size());
  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
  reduce_dims.erase(reduce_dims.begin() + dim);

  // For cuda half, calculate norm in float precision then cast
  // normalization factor to half
  auto dtype = self.scalar_type();
  auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
  Tensor norm;
  if (acc_type != dtype) {
    norm = at::linalg_vector_norm(self, p.toDouble(), reduce_dims,
                                  /*keepdim=*/true, /*dtype=*/acc_type);
  } else {
    norm = at::linalg_vector_norm(self, p.toDouble(), reduce_dims,
                                  /*keepdim=*/true);
  }

  auto factor = (acc_type == c10::toRealValueType(dtype)) ?
      norm : at::empty(norm.sizes(), self.options());
  auto iter = TensorIteratorConfig()
      .add_output(factor)
      .add_input(norm)
      .set_check_mem_overlap(false)
      .cast_common_dtype_to_outputs(true)
      .build();

  renorm_scale_factor_stub(iter.device_type(), iter, maxnorm.toDouble());
  at::mul_outf(self, factor, const_cast<Tensor&>(out));
}

} // at::native