xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/group_norm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/DispatchStub.h>
4 #include <cstdint>
5 
6 namespace at {
7 class Tensor;
8 
9 namespace native {
10 
11 using forward_fn = void (*)(
12     const Tensor& /* X */,
13     const Tensor& /* gamma */,
14     const Tensor& /* beta */,
15     int64_t /* N */,
16     int64_t /* C */,
17     int64_t /* HxW */,
18     int64_t /* group */,
19     double /* eps */,
20     Tensor& /* Y */,
21     Tensor& /* mean */,
22     Tensor& /* rstd */);
23 
24 using backward_fn = void (*)(
25     const Tensor& /* dY */,
26     const Tensor& /* X */,
27     const Tensor& /* mean */,
28     const Tensor& /* rstd */,
29     const Tensor& /* gamma */,
30     int64_t /* N */,
31     int64_t /* C */,
32     int64_t /* HxW */,
33     int64_t /* group */,
34     Tensor& /* dX */,
35     Tensor& /* dgamma */,
36     Tensor& /* dbeta */);
37 
38 DECLARE_DISPATCH(forward_fn, GroupNormKernel);
39 DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel);
40 
41 } // namespace native
42 } // namespace at
43