#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/group_norm.h>
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/accumulate.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/group_norm_native.h>
#include <ATen/ops/native_batch_norm.h>
#include <ATen/ops/native_group_norm.h>
#include <ATen/ops/native_group_norm_backward_native.h>
#include <ATen/ops/native_group_norm_native.h>
#endif

#include <array>
#include <functional>
#include <tuple>
#include <vector>

namespace at::native {

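// Group normalization splits the C channels of an (N, C, *) input into
// `group` groups and normalizes each (sample, group) slice to zero mean and
// unit variance:
//   y = (x - mean) / sqrt(var + eps)
// followed by an optional per-channel affine transform gamma * y + beta.
// The forward op also returns the per-(sample, group) mean and rstd
// (1 / sqrt(var + eps)), which the backward pass reuses.
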
template <typename T>
void check_group_norm_inputs(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    T C,
    int64_t num_groups) {
  TORCH_CHECK(
      num_groups > 0,
      "Expected num groups to be greater than 0, got ", num_groups);
  TORCH_CHECK(
      C % num_groups == 0,
      "Expected number of channels in input to be divisible by ",
      "num_groups, but got input of shape ",
      input.sizes(),
      " and "
      "num_groups=",
      num_groups);
  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && at::symint::numel<T>(weight) == C),
      "Expected weight to be a vector of size equal to the number of ",
      "channels in input, but got weight of shape ",
      weight.sizes(),
      " and input of shape ",
      input.sizes());
  TORCH_CHECK(
      !bias.defined() || (bias.dim() == 1 && at::symint::numel<T>(bias) == C),
      "Expected bias to be a vector of size equal to the number of ",
      "channels in input, but got bias of shape ",
      bias.sizes(),
      " and input of shape ",
      input.sizes());
}

std::tuple<Tensor, Tensor, Tensor> native_group_norm(
    const Tensor& X,
    const std::optional<Tensor>& gamma_opt /* optional */,
    const std::optional<Tensor>& beta_opt /* optional */,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> gamma_maybe_owned =
      at::borrow_from_optional_tensor(gamma_opt);
  const Tensor& gamma = *gamma_maybe_owned;
  const Tensor& beta = c10::value_or_else(beta_opt, [] { return Tensor(); });

  // repeated check so expanded weights can call native_group_norm directly but
  // save mean and variance from forward
  check_group_norm_inputs(X, gamma, beta, C, group);
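  // On CPU the kernels support channels-last, so keep the input's suggested
  // memory format; on other backends fall back to the default contiguous
  // layout.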
  auto memory_format = X.device().is_cpu() ?
      X.suggest_memory_format() : at::MemoryFormat::Contiguous;

  TORCH_CHECK(X.is_contiguous(memory_format));

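  // Mixed-type support: the input may be a reduced floating point type
  // (e.g. BFloat16 or Half) while gamma/beta are float. In that case mean and
  // rstd are allocated with the parameter dtype reported by
  // param_scalar_type() below.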
  bool mixed_type = is_mixed_type(X, gamma, beta);
  if (mixed_type) {
    check_mixed_data_type(X, gamma, beta);
  }

  Tensor Y = at::native::empty_like(
      X,
      std::nullopt /* dtype */,
      std::nullopt /* layout */,
      std::nullopt /* device */,
      std::nullopt /* pin_memory */,
      memory_format);
  const auto dtype = param_scalar_type(X, mixed_type);
  Tensor mean = at::empty({N, group}, X.options().dtype(dtype));
  Tensor rstd = at::empty({N, group}, X.options().dtype(dtype));
  GroupNormKernel(
      X.device().type(), X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
  return std::make_tuple(Y, mean, rstd);
}

std::tuple<Tensor, Tensor, Tensor> native_group_norm_backward(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& gamma_opt,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    std::array<bool, 3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> gamma_maybe_owned =
      at::borrow_from_optional_tensor(gamma_opt);
  const Tensor& gamma = *gamma_maybe_owned;
  TORCH_CHECK(
      X.scalar_type() == dY.scalar_type(),
      "Expected X and dY to have the same scalar type.");
  bool mixed_type = is_mixed_type(X, mean, rstd);
  if (mixed_type) {
    check_mixed_data_type(X, mean, rstd);
  }
  auto memory_format = X.device().is_cpu() ?
      X.suggest_memory_format() : at::MemoryFormat::Contiguous;

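  // grad_input_mask selects which gradients the caller actually needs:
  // [0] -> dX, [1] -> dgamma, [2] -> dbeta. Unrequested outputs stay
  // undefined tensors.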
  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;
  if (grad_input_mask[0]) {
    dX = at::native::empty_like(
        X,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        memory_format);
  }
  if (grad_input_mask[1]) {
    dgamma = at::native::empty_like(
        gamma,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[2]) {
    dbeta = at::native::empty_like(
        gamma,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  GroupNormBackwardKernel(
      X.device().type(),
      dY,
      X,
      mean,
      rstd,
      gamma,
      N,
      C,
      HxW,
      group,
      dX,
      dgamma,
      dbeta);
  return std::make_tuple(dX, dgamma, dbeta);
}

Tensor group_norm(
    const Tensor& input,
    int64_t num_groups,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    double eps,
    bool /* cudnn_enabled, deprecated */) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); });

  const auto N = input.sym_size(0);
  const auto C = input.sym_size(1);
  check_group_norm_inputs(input, weight, bias, C, num_groups);

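  // HxW is the number of elements per channel, i.e. the product of all
  // dimensions after (N, C); for a 4-D input this is H * W.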
  const auto input_shape = input.sym_sizes();
  const auto HxW =
      c10::multiply_integers(input_shape.slice(2));

  const Tensor kEmpty;
  auto memory_format = input.suggest_memory_format();
  const auto& X = input.device().is_cpu() ? input.contiguous(memory_format) : input.contiguous();
  const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty;
  const auto& beta = bias.defined() ? bias.contiguous() : kEmpty;
  TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C);
  TORCH_CHECK(!beta.defined() || beta.sym_numel() == C);
  return std::get<0>(
      at::native_group_norm_symint(X, gamma, beta, N, C, HxW, num_groups, eps));
}
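
// For reference, a typical call through the generated C++ API looks roughly
// like the following (illustrative sketch only, not part of this file):
//   at::Tensor x = at::randn({2, 6, 4, 4});
//   at::Tensor w = at::ones({6});
//   at::Tensor b = at::zeros({6});
//   at::Tensor y = at::group_norm(x, /*num_groups=*/3, w, b, /*eps=*/1e-5,
//                                 /*cudnn_enabled=*/false);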

// GroupNormKernel / GroupNormBackwardKernel are dispatch stubs; the
// device-specific implementations (CPU, CUDA, ...) are registered elsewhere
// via REGISTER_DISPATCH.
DEFINE_DISPATCH(GroupNormKernel);
DEFINE_DISPATCH(GroupNormBackwardKernel);

// Ported from pytorch/xla repo
std::tuple<at::Tensor, at::Tensor, at::Tensor> math_group_norm(
    const Tensor& input,
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); });

  auto input_shape = input.sizes();
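  // Composite implementation: reshape to (1, N * group, -1) so that batch
  // norm's per-channel statistics coincide with group norm's per-(sample,
  // group) statistics. The `N ? -1 : 1` keeps the view valid when the batch
  // is empty.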
  at::Tensor input_reshaped = input.view({1, N * group, N ? -1 : 1});
  auto outputs = at::native_batch_norm(
      input_reshaped,
      /*weight=*/{},
      /*bias=*/{},
      /*running_mean=*/{},
      /*running_var=*/{},
      /*training=*/true,
      /*momentum=*/0,
      eps);
  at::Tensor out = std::get<0>(outputs);
  out = out.view(input_shape);
  std::vector<int64_t> affine_param_shape(input.dim(), 1);
  affine_param_shape[1] = C;
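  // Apply the per-channel affine transform by broadcasting weight/bias to
  // shape (1, C, 1, ..., 1); addcmul(out, weight, 1) computes
  // bias + out * weight in a single step.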
  if (weight.defined() && bias.defined()) {
    out = bias.view(affine_param_shape)
              .addcmul(out, weight.view(affine_param_shape), 1);
  } else if (weight.defined()) {
    out = out.mul(weight.view(affine_param_shape));
  } else if (bias.defined()) {
    out = out.add(bias.view(affine_param_shape));
  }
  // Convert mean/rstd to have the same dtype as the input.
  // This follows the same behavior as the CPU and CUDA kernels.
  at::Tensor mean = std::get<1>(outputs).to(c10::TensorOptions().dtype(input.scalar_type())).view({N, group});
  at::Tensor rstd = std::get<2>(outputs).to(c10::TensorOptions().dtype(input.scalar_type())).view({N, group});
  return std::make_tuple(out, mean, rstd);
}
} // namespace at::native