1 #pragma once 2 3 #include <ATen/native/DispatchStub.h> 4 #include <c10/core/ScalarType.h> 5 #include <cstdint> 6 7 namespace at::native { 8 9 using unfold2d_copy_fn = void (*)( 10 ScalarType dtype, 11 void *finput, 12 const void *input, 13 int64_t kH, 14 int64_t kW, 15 int64_t dH, 16 int64_t dW, 17 int64_t padH, 18 int64_t padW, 19 int64_t n_input_plane, 20 int64_t input_height, 21 int64_t input_width, 22 int64_t output_height, 23 int64_t output_width, 24 bool is_channels_last 25 ); 26 27 using unfold2d_acc_fn = void (*)( 28 ScalarType dtype, 29 void *finput, 30 void *input, 31 int64_t kH, 32 int64_t kW, 33 int64_t dH, 34 int64_t dW, 35 int64_t padH, 36 int64_t padW, 37 int64_t n_input_plane, 38 int64_t input_height, 39 int64_t input_width, 40 int64_t output_height, 41 int64_t output_width, 42 bool is_channels_last 43 ); 44 45 DECLARE_DISPATCH(unfold2d_copy_fn, unfolded2d_copy_stub); 46 DECLARE_DISPATCH(unfold2d_acc_fn, unfolded2d_acc_stub); 47 48 } // namespace at::native 49