xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/GridSamplerKernel.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/DispatchStub.h>
4 
5 #include <array>
6 #include <cstdint>
7 
// Forward declaration only: the kernel signatures in this header take
// `at::TensorBase` by const reference, so the complete type is not required.
namespace at { class TensorBase; }
11 
12 namespace at::native {
13 
14 using forward_2d_fn = void (*) (
15     const TensorBase &output,
16     const TensorBase &input,
17     const TensorBase &grid,
18     int64_t interpolation_mode,
19     int64_t padding_mode,
20     bool align_corners);
21 using backward_2d_fn = void (*) (
22     const TensorBase &grad_input,
23     const TensorBase &grad_grid,
24     const TensorBase &grad_output,
25     const TensorBase &input,
26     const TensorBase &grid,
27     int64_t interpolation_mode,
28     int64_t padding_mode,
29     bool align_corners,
30     std::array<bool, 2> output_mask);
31 DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel);
32 DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel);
33 
34 } // namespace at::native
35