1 #pragma once 2 3 #include <ATen/native/DispatchStub.h> 4 #include <cstdint> 5 6 namespace at { 7 class Tensor; 8 9 namespace native { 10 11 using forward_fn = void (*)( 12 const Tensor& /* X */, 13 const Tensor& /* gamma */, 14 const Tensor& /* beta */, 15 int64_t /* N */, 16 int64_t /* C */, 17 int64_t /* HxW */, 18 int64_t /* group */, 19 double /* eps */, 20 Tensor& /* Y */, 21 Tensor& /* mean */, 22 Tensor& /* rstd */); 23 24 using backward_fn = void (*)( 25 const Tensor& /* dY */, 26 const Tensor& /* X */, 27 const Tensor& /* mean */, 28 const Tensor& /* rstd */, 29 const Tensor& /* gamma */, 30 int64_t /* N */, 31 int64_t /* C */, 32 int64_t /* HxW */, 33 int64_t /* group */, 34 Tensor& /* dX */, 35 Tensor& /* dgamma */, 36 Tensor& /* dbeta */); 37 38 DECLARE_DISPATCH(forward_fn, GroupNormKernel); 39 DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel); 40 41 } // namespace native 42 } // namespace at 43