#pragma once
#include <array>
#include <cstdint>

// Declarations of the grid_sampler 2D/3D forward and backward kernel
// launchers in at::native. Only declarations live in this header; the
// definitions are in another translation unit (presumably the CUDA
// implementation -- TODO confirm).
//
// Integer parameters are spelled std::int64_t: <cstdint> is only guaranteed
// to declare the fixed-width integer types inside namespace std, so the bare
// ::int64_t spelling depends on an implementation extension. Both spellings
// name the same type, so these declarations still match definitions that
// are written with plain int64_t.

namespace at {
// A forward declaration is sufficient here: every tensor parameter below is
// taken by const reference, so the complete type is not required.
class TensorBase;
} // namespace at

namespace at {
namespace native {

/// @brief Launches the 2D grid-sampler forward kernel.
/// @param output  destination tensor for the sampled result. NOTE(review):
///                presumably allocated by the caller and written by the
///                kernel through this const handle -- confirm.
/// @param input   tensor being sampled.
/// @param grid    tensor of sampling locations.
/// @param interpolation_mode  integer-encoded interpolation mode; the
///                encoding is defined by the implementation, not here.
/// @param padding_mode  integer-encoded padding mode (same caveat).
/// @param align_corners  boolean flag forwarded to the kernel.
void launch_grid_sampler_2d_forward_kernel(
    const TensorBase &output, const TensorBase &input, const TensorBase &grid,
    std::int64_t interpolation_mode, std::int64_t padding_mode,
    bool align_corners);

/// @brief Launches the 3D grid-sampler forward kernel.
/// Same parameter contract as launch_grid_sampler_2d_forward_kernel, for the
/// 3D variant of the operation.
void launch_grid_sampler_3d_forward_kernel(
    const TensorBase &output, const TensorBase &input, const TensorBase &grid,
    std::int64_t interpolation_mode, std::int64_t padding_mode,
    bool align_corners);

/// @brief Launches the 2D grid-sampler backward kernel.
/// @param grad_input   destination for the gradient w.r.t. `input`.
/// @param grad_grid    destination for the gradient w.r.t. `grid`.
/// @param grad_output  incoming gradient of the forward output.
/// @param input        the forward pass's input tensor.
/// @param grid         the forward pass's sampling-location tensor.
/// @param interpolation_mode  integer-encoded interpolation mode.
/// @param padding_mode  integer-encoded padding mode.
/// @param align_corners  boolean flag forwarded to the kernel.
/// @param output_mask  two flags, passed by value. NOTE(review): presumably
///                [0] selects whether grad_input is produced and [1] whether
///                grad_grid is -- confirm against the implementation.
void launch_grid_sampler_2d_backward_kernel(
    const TensorBase &grad_input, const TensorBase &grad_grid,
    const TensorBase &grad_output, const TensorBase &input,
    const TensorBase &grid, std::int64_t interpolation_mode,
    std::int64_t padding_mode, bool align_corners,
    std::array<bool, 2> output_mask);

/// @brief Launches the 3D grid-sampler backward kernel.
/// Same parameter contract as launch_grid_sampler_2d_backward_kernel, for
/// the 3D variant of the operation.
void launch_grid_sampler_3d_backward_kernel(
    const TensorBase &grad_input, const TensorBase &grad_grid,
    const TensorBase &grad_output, const TensorBase &input,
    const TensorBase &grid, std::int64_t interpolation_mode,
    std::int64_t padding_mode, bool align_corners,
    std::array<bool, 2> output_mask);

} // namespace native
} // namespace at