xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/GridSampler.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/cuda/GridSampler.h>
3 
4 #ifndef AT_PER_OPERATOR_HEADERS
5 #include <ATen/Functions.h>
6 #include <ATen/NativeFunctions.h>
7 #else
8 #include <ATen/ops/empty.h>
9 #include <ATen/ops/empty_like.h>
10 #include <ATen/ops/grid_sampler_2d_backward_native.h>
11 #include <ATen/ops/grid_sampler_2d_native.h>
12 #include <ATen/ops/grid_sampler_3d_backward_native.h>
13 #include <ATen/ops/grid_sampler_3d_native.h>
14 #include <ATen/ops/zeros_like.h>
15 #endif
16 
17 namespace at::native {
18 
// CUDA entry point for the 2-D grid sampler forward pass.
//
// Allocates the result with at::empty using input's options, then hands the
// actual computation to launch_grid_sampler_2d_forward_kernel.
//
// The result shape is {input.size(0), input.size(1), grid.size(1),
// grid.size(2)}. The extents are read here without any rank validation.
// NOTE(review): presumably the caller or the launch function validates the
// shapes of input and grid — confirm.
Tensor grid_sampler_2d_cuda(const Tensor& input, const Tensor& grid,
                            int64_t interpolation_mode, int64_t padding_mode,
                            bool align_corners) {
  const auto input_shape = input.sizes();
  const auto grid_shape = grid.sizes();

  Tensor result = at::empty(
      {input_shape[0], input_shape[1], grid_shape[1], grid_shape[2]},
      input.options());

  launch_grid_sampler_2d_forward_kernel(
      result, input, grid, interpolation_mode, padding_mode, align_corners);
  return result;
}
30 
// CUDA entry point for the 3-D grid sampler forward pass.
//
// Allocates the result with at::empty using input's options, then hands the
// actual computation to launch_grid_sampler_3d_forward_kernel.
//
// The result shape is {input.size(0), input.size(1), grid.size(1),
// grid.size(2), grid.size(3)}. The extents are read here without any rank
// validation.
// NOTE(review): presumably the caller or the launch function validates the
// shapes of input and grid — confirm.
Tensor grid_sampler_3d_cuda(const Tensor& input, const Tensor& grid,
                            int64_t interpolation_mode, int64_t padding_mode,
                            bool align_corners) {
  const auto input_shape = input.sizes();
  const auto grid_shape = grid.sizes();

  Tensor result = at::empty(
      {input_shape[0], input_shape[1],
       grid_shape[1], grid_shape[2], grid_shape[3]},
      input.options());

  launch_grid_sampler_3d_forward_kernel(
      result, input, grid, interpolation_mode, padding_mode, align_corners);
  return result;
}
43 
44 std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cuda(const Tensor & grad_output,const Tensor & input,const Tensor & grid,int64_t interpolation_mode,int64_t padding_mode,bool align_corners,std::array<bool,2> output_mask)45 grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input,
46                               const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
47                               bool align_corners, std::array<bool, 2> output_mask) {
48   auto input_requires_grad = output_mask[0];
49   Tensor grad_input = ([&]() {
50     if (input_requires_grad) {
51       return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
52     } else {
53       return Tensor();
54     }
55   })();
56   auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
57   launch_grid_sampler_2d_backward_kernel(
58       grad_input, grad_grid, grad_output, input,
59       grid, interpolation_mode, padding_mode, align_corners, output_mask);
60   return std::make_tuple(grad_input, grad_grid);
61 }
62 
63 std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cuda(const Tensor & grad_output,const Tensor & input,const Tensor & grid,int64_t interpolation_mode,int64_t padding_mode,bool align_corners,std::array<bool,2> output_mask)64 grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
65                               const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
66                               bool align_corners, std::array<bool,2> output_mask) {
67   auto input_requires_grad = output_mask[0];
68   Tensor grad_input = ([&]() {
69     if (input_requires_grad) {
70       return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
71     } else {
72       return Tensor();
73     }
74   })();
75   auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
76   launch_grid_sampler_3d_backward_kernel(
77       grad_input, grad_grid, grad_output, input,
78       grid, interpolation_mode, padding_mode, align_corners, output_mask);
79   return std::make_tuple(grad_input, grad_grid);
80 }
81 
82 }  // namespace at::native
83