1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/cuda/GridSampler.h>
3
4 #ifndef AT_PER_OPERATOR_HEADERS
5 #include <ATen/Functions.h>
6 #include <ATen/NativeFunctions.h>
7 #else
8 #include <ATen/ops/empty.h>
9 #include <ATen/ops/empty_like.h>
10 #include <ATen/ops/grid_sampler_2d_backward_native.h>
11 #include <ATen/ops/grid_sampler_2d_native.h>
12 #include <ATen/ops/grid_sampler_3d_backward_native.h>
13 #include <ATen/ops/grid_sampler_3d_native.h>
14 #include <ATen/ops/zeros_like.h>
15 #endif
16
17 namespace at::native {
18
// Forward 2-D grid sampling on CUDA.
// Allocates the output tensor — dims 0/1 taken from `input`, dims 1/2 from
// `grid` (so output is [input[0], input[1], grid[1], grid[2]], with input's
// dtype/device) — then hands off to the CUDA kernel launcher, which fills it.
Tensor grid_sampler_2d_cuda(const Tensor& input, const Tensor& grid,
                            int64_t interpolation_mode, int64_t padding_mode,
                            bool align_corners) {
  const auto input_sizes = input.sizes();
  const auto grid_sizes = grid.sizes();
  Tensor out = at::empty(
      {input_sizes[0], input_sizes[1], grid_sizes[1], grid_sizes[2]},
      input.options());
  launch_grid_sampler_2d_forward_kernel(
      out, input, grid, interpolation_mode, padding_mode, align_corners);
  return out;
}
30
// Forward 3-D grid sampling on CUDA.
// Mirrors the 2-D variant with one extra spatial dimension: the output tensor
// is [input[0], input[1], grid[1], grid[2], grid[3]] with input's
// dtype/device, and the CUDA kernel launcher populates it.
Tensor grid_sampler_3d_cuda(const Tensor& input, const Tensor& grid,
                            int64_t interpolation_mode, int64_t padding_mode,
                            bool align_corners) {
  const auto input_sizes = input.sizes();
  const auto grid_sizes = grid.sizes();
  Tensor out = at::empty(
      {input_sizes[0], input_sizes[1], grid_sizes[1], grid_sizes[2],
       grid_sizes[3]},
      input.options());
  launch_grid_sampler_3d_forward_kernel(
      out, input, grid, interpolation_mode, padding_mode, align_corners);
  return out;
}
43
44 std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cuda(const Tensor & grad_output,const Tensor & input,const Tensor & grid,int64_t interpolation_mode,int64_t padding_mode,bool align_corners,std::array<bool,2> output_mask)45 grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input,
46 const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
47 bool align_corners, std::array<bool, 2> output_mask) {
48 auto input_requires_grad = output_mask[0];
49 Tensor grad_input = ([&]() {
50 if (input_requires_grad) {
51 return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
52 } else {
53 return Tensor();
54 }
55 })();
56 auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
57 launch_grid_sampler_2d_backward_kernel(
58 grad_input, grad_grid, grad_output, input,
59 grid, interpolation_mode, padding_mode, align_corners, output_mask);
60 return std::make_tuple(grad_input, grad_grid);
61 }
62
63 std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cuda(const Tensor & grad_output,const Tensor & input,const Tensor & grid,int64_t interpolation_mode,int64_t padding_mode,bool align_corners,std::array<bool,2> output_mask)64 grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
65 const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
66 bool align_corners, std::array<bool,2> output_mask) {
67 auto input_requires_grad = output_mask[0];
68 Tensor grad_input = ([&]() {
69 if (input_requires_grad) {
70 return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
71 } else {
72 return Tensor();
73 }
74 })();
75 auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
76 launch_grid_sampler_3d_backward_kernel(
77 grad_input, grad_grid, grad_output, input,
78 grid, interpolation_mode, padding_mode, align_corners, output_mask);
79 return std::make_tuple(grad_input, grad_grid);
80 }
81
82 } // namespace at::native
83