1 #pragma once 2 3 #include <ATen/native/DispatchStub.h> 4 5 #include <array> 6 #include <cstdint> 7 8 namespace at { 9 class TensorBase; 10 } 11 12 namespace at::native { 13 14 using forward_2d_fn = void (*) ( 15 const TensorBase &output, 16 const TensorBase &input, 17 const TensorBase &grid, 18 int64_t interpolation_mode, 19 int64_t padding_mode, 20 bool align_corners); 21 using backward_2d_fn = void (*) ( 22 const TensorBase &grad_input, 23 const TensorBase &grad_grid, 24 const TensorBase &grad_output, 25 const TensorBase &input, 26 const TensorBase &grid, 27 int64_t interpolation_mode, 28 int64_t padding_mode, 29 bool align_corners, 30 std::array<bool, 2> output_mask); 31 DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel); 32 DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel); 33 34 } // namespace at::native 35