#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/group_norm.h>
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/accumulate.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/group_norm_native.h>
#include <ATen/ops/native_batch_norm.h>
#include <ATen/ops/native_group_norm.h>
#include <ATen/ops/native_group_norm_backward_native.h>
#include <ATen/ops/native_group_norm_native.h>
#endif

#include <array>
#include <functional>
#include <tuple>
#include <vector>

namespace at::native {

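// Validates the arguments shared by group_norm and native_group_norm:
// num_groups must be positive, the channel count C must be divisible by
// num_groups, and weight/bias (if defined) must be 1-D tensors with C
// elements. T is either int64_t or c10::SymInt.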
template <typename T>
void check_group_norm_inputs(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    T C,
    int64_t num_groups) {
  TORCH_CHECK(
      num_groups > 0,
      "Expected num groups to be greater than 0, got ", num_groups);
  TORCH_CHECK(
      C % num_groups == 0,
      "Expected number of channels in input to be divisible by ",
      "num_groups, but got input of shape ",
      input.sizes(),
      " and "
      "num_groups=",
      num_groups);
  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && at::symint::numel<T>(weight) == C),
      "Expected weight to be a vector of size equal to the number of ",
      "channels in input, but got weight of shape ",
      weight.sizes(),
      " and input of shape ",
      input.sizes());
  TORCH_CHECK(
      !bias.defined() || (bias.dim() == 1 && at::symint::numel<T>(bias) == C),
      "Expected bias to be a vector of size equal to the number of ",
      "channels in input, but got bias of shape ",
      bias.sizes(),
58 " and input of shape ",
59 input.sizes());
60 }
61
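// Forward pass dispatched to the per-device GroupNormKernel. X is treated as
// an (N, C, HxW) tensor with C split into `group` groups; returns the
// normalized output Y together with the per-(sample, group) mean and rstd
// (reciprocal of the standard deviation) saved for the backward pass.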
std::tuple<Tensor, Tensor, Tensor> native_group_norm(
    const Tensor& X,
    const std::optional<Tensor>& gamma_opt /* optional */,
    const std::optional<Tensor>& beta_opt /* optional */,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> gamma_maybe_owned =
      at::borrow_from_optional_tensor(gamma_opt);
  const Tensor& gamma = *gamma_maybe_owned;
  const Tensor& beta = c10::value_or_else(beta_opt, [] { return Tensor(); });

  // repeated check so expanded weights can call native_group_norm directly but
  // save mean and variance from forward
  check_group_norm_inputs(X, gamma, beta, C, group);
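  // On CPU the kernel supports channels-last layouts, so honor the suggested
  // memory format; other backends expect a plain contiguous input.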
  auto memory_format = X.device().is_cpu() ?
      X.suggest_memory_format() : at::MemoryFormat::Contiguous;

  TORCH_CHECK(X.is_contiguous(memory_format));

  bool mixed_type = is_mixed_type(X, gamma, beta);
  if (mixed_type) {
    check_mixed_data_type(X, gamma, beta);
  }

  Tensor Y = at::native::empty_like(
      X,
      std::nullopt /* dtype */,
      std::nullopt /* layout */,
      std::nullopt /* device */,
      std::nullopt /* pin_memory */,
      memory_format);
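  // mean/rstd are allocated in the dtype chosen by param_scalar_type: the
  // higher-precision parameter dtype when X and gamma/beta have mixed
  // precision (e.g. bfloat16 input with float parameters), otherwise X's own
  // dtype.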
  const auto dtype = param_scalar_type(X, mixed_type);
  Tensor mean = at::empty({N, group}, X.options().dtype(dtype));
  Tensor rstd = at::empty({N, group}, X.options().dtype(dtype));
  GroupNormKernel(
      X.device().type(), X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
  return std::make_tuple(Y, mean, rstd);
}

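// Backward pass. Computes the gradients selected by grad_input_mask
// ({dX, dgamma, dbeta}) from the saved mean/rstd and dispatches to the
// per-device GroupNormBackwardKernel.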
std::tuple<Tensor, Tensor, Tensor> native_group_norm_backward(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& gamma_opt,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    std::array<bool, 3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> gamma_maybe_owned =
      at::borrow_from_optional_tensor(gamma_opt);
  const Tensor& gamma = *gamma_maybe_owned;
  TORCH_CHECK(
      X.scalar_type() == dY.scalar_type(),
122 "Expected scalar types of X and dY are same.");
  bool mixed_type = is_mixed_type(X, mean, rstd);
  if (mixed_type) {
    check_mixed_data_type(X, mean, rstd);
  }
  auto memory_format = X.device().is_cpu() ?
      X.suggest_memory_format() : at::MemoryFormat::Contiguous;

  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;
  if (grad_input_mask[0]) {
    dX = at::native::empty_like(
        X,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        memory_format);
  }
  if (grad_input_mask[1]) {
    dgamma = at::native::empty_like(
        gamma,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[2]) {
    dbeta = at::native::empty_like(
        gamma,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  GroupNormBackwardKernel(
      X.device().type(),
      dY,
      X,
      mean,
      rstd,
      gamma,
      N,
      C,
      HxW,
      group,
      dX,
      dgamma,
      dbeta);
  return std::make_tuple(dX, dgamma, dbeta);
}

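// User-facing composite op behind torch.nn.functional.group_norm. Derives
// N, C, and HxW from the input shape, validates the arguments, and forwards
// to native_group_norm_symint. The trailing bool (cudnn_enabled) is ignored
// and kept only for backward compatibility.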
Tensor group_norm(
    const Tensor& input,
    int64_t num_groups,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    double eps,
    bool /* cudnn_enabled, deprecated */) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); });

  const auto N = input.sym_size(0);
  const auto C = input.sym_size(1);
  check_group_norm_inputs(input, weight, bias, C, num_groups);

  const auto input_shape = input.sym_sizes();
  const auto HxW =
      c10::multiply_integers(input_shape.slice(2));

  const Tensor kEmpty;
  auto memory_format = input.suggest_memory_format();
  const auto& X = input.device().is_cpu() ? input.contiguous(memory_format) : input.contiguous();
  const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty;
  const auto& beta = bias.defined() ? bias.contiguous() : kEmpty;
  TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C);
  TORCH_CHECK(!beta.defined() || beta.sym_numel() == C);
  return std::get<0>(
      at::native_group_norm_symint(X, gamma, beta, N, C, HxW, num_groups, eps));
}

DEFINE_DISPATCH(GroupNormKernel);
DEFINE_DISPATCH(GroupNormBackwardKernel);

// Ported from pytorch/xla repo
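// Reference (composite) implementation used by backends without a dedicated
// group-norm kernel: it reinterprets the (N, C, *) input as (1, N * group, -1)
// and reuses native_batch_norm to compute the per-group statistics, then
// applies the optional per-channel affine transform.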
std::tuple<at::Tensor, at::Tensor, at::Tensor> math_group_norm(
    const Tensor& input,
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); });

  auto input_shape = input.sizes();
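  // Each (sample, group) pair becomes one "channel" of a single-sample batch,
  // so training-mode batch norm produces exactly the per-group mean/var.
  // The `N ? -1 : 1` keeps the view valid for empty batches (N == 0).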
  at::Tensor input_reshaped = input.view({1, N * group, N ? -1 : 1});
  auto outputs = at::native_batch_norm(
      input_reshaped,
      /*weight=*/{},
      /*bias=*/{},
      /*running_mean=*/{},
      /*running_var=*/{},
      /*training=*/true,
      /*momentum=*/0,
      eps);
  at::Tensor out = std::get<0>(outputs);
  out = out.view(input_shape);
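  // Reshape gamma/beta to (1, C, 1, ..., 1) so they broadcast over the
  // normalized output along the channel dimension.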
  std::vector<int64_t> affine_param_shape(input.dim(), 1);
  affine_param_shape[1] = C;
  if (weight.defined() && bias.defined()) {
    out = bias.view(affine_param_shape)
              .addcmul(out, weight.view(affine_param_shape), 1);
  } else if (weight.defined()) {
    out = out.mul(weight.view(affine_param_shape));
  } else if (bias.defined()) {
    out = out.add(bias.view(affine_param_shape));
  }
  // convert mean/std to have the same dtype as input.
  // This follows the same behavior as the CPU and CUDA kernels.
  at::Tensor mean = std::get<1>(outputs).to(c10::TensorOptions().dtype(input.scalar_type())).view({N, group});
  at::Tensor rstd = std::get<2>(outputs).to(c10::TensorOptions().dtype(input.scalar_type())).view({N, group});
  return std::make_tuple(out, mean, rstd);
}
} // namespace at::native