xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/GridSampler.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Config.h>
3 #include <ATen/core/Tensor.h>
4 #include <ATen/cuda/CUDAConfig.h>
5 #include <ATen/native/GridSamplerUtils.h>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/cudnn_grid_sampler_backward_native.h>
12 #include <ATen/ops/cudnn_grid_sampler_native.h>
13 #include <ATen/ops/empty.h>
14 #endif
15 
16 #if !AT_CUDNN_ENABLED()
17 
18 namespace at {
19 namespace native {
20 
21 // See Note [ATen preprocessor philosophy]
22 
// Fallback definition used when ATen is built without cuDNN. There is no
// cuDNN sampler to dispatch to, so every call raises a c10::Error.
//
// TORCH_CHECK(false, ...) replaces the deprecated AT_ERROR alias (which c10
// defines as exactly this) as the recommended way to fail unconditionally.
Tensor cudnn_grid_sampler_forward(const Tensor& input_t, const Tensor& grid_t) {
  TORCH_CHECK(
      false,
      "cudnn_grid_sampler_forward: ATen not compiled with cuDNN support");
}
26 
cudnn_grid_sampler_backward(const Tensor & input_t,const Tensor & grid_t,const Tensor & grad_output_t)27 std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
28     const Tensor& input_t,
29     const Tensor& grid_t,
30     const Tensor& grad_output_t) {
31   AT_ERROR("cudnn_grid_sampler_backward: ATen not compiled with cuDNN support");
32 }
33 
34 } // namespace native
35 } // namespace at
36 
37 #else // AT_CUDNN_ENABLED
38 
39 #include <ATen/cuda/Exceptions.h>
40 #include <ATen/cudnn/Descriptors.h>
41 #include <ATen/cudnn/Types.h>
42 #include <ATen/cudnn/Utils.h>
43 
44 #include <ATen/TensorUtils.h>
45 #include <c10/util/irange.h>
46 
47 // TODO: descriptor checking
48 
49 namespace at {
50 namespace native {
51 
52 namespace {
53 
setSamplerDescriptor(SpatialTransformerDescriptor & desc,cudnnDataType_t dataType,const at::Tensor & tensor)54 void setSamplerDescriptor(
55     SpatialTransformerDescriptor& desc,
56     cudnnDataType_t dataType,
57     const at::Tensor& tensor) {
58   int inputSize[4] = {0};
59   for (const auto i : c10::irange(tensor.dim())) {
60     inputSize[i] = (int)tensor.size(i);
61   }
62   desc.set(dataType, 4, inputSize);
63 }
64 
checkGridSize(CheckedFrom c,TensorArg grid,TensorArg input)65 void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input) {
66   // assert size of grid is n*h*w*2
67   // FYI: grid is between [-1, 1], where -1 left most pixel,
68   // 1 represents right most pixel (and hence 0 is the center pixel)
69   // if grid has values >1 or <-1, those values are ignored
70   checkContiguous(c, grid);
71   checkDim(c, grid, 4);
72   // TODO: Maybe more user friendly to report where the expected size
73   // came from
74   checkSize(c, grid, 0, input->size(0));
75   checkSize(c, grid, 3, 2);
76 }
77 
78 } // namespace
79 
// Forward pass of the cuDNN spatial-transformer sampler.
//
// Samples `input_t` (4-D) at the locations stored in `grid_t` and returns a
// freshly allocated tensor of shape
//   {input.size(0), input.size(1), grid.size(1), grid.size(2)}
// with the input's dtype and device.
//
// The arguments are re-validated here because this op can be reached directly
// instead of through grid_sampler (see NOTE [ grid_sampler Native Functions ]).
Tensor cudnn_grid_sampler_forward(const Tensor& input_t, const Tensor& grid_t) {
  check_grid_sampler_common(input_t, grid_t);
  TORCH_CHECK(
      cond_cudnn_grid_sampler(input_t, grid_t),
      "Invalid arguments to cudnn_grid_sampler_forward");

  // The input is only materialized when it has a zero stride; presumably the
  // cuDNN tensor descriptor cannot express those. The grid is handed to cuDNN
  // as a bare pointer with no descriptor, so it must be dense.
  const auto input_dense = contiguousIfZeroInStrides(input_t);
  const auto grid_dense = grid_t.contiguous();

  TensorArg input_arg{input_dense, "input", 1};
  TensorArg grid_arg{grid_dense, "grid", 2};
  CheckedFrom checked_from = "cudnn_grid_sampler_forward";
  checkAllSameGPU(checked_from, {input_arg, grid_arg});
  checkAllSameType(checked_from, {input_arg, grid_arg});
  checkGridSize(checked_from, grid_arg, input_arg);
  checkDim(checked_from, input_arg, 4);

  // Batch and channels come from the input; spatial extent from the grid.
  const int64_t batch = input_arg->size(0);
  const int64_t channels = input_arg->size(1);
  const int64_t out_height = grid_arg->size(1);
  const int64_t out_width = grid_arg->size(2);
  Tensor output_t =
      at::empty({batch, channels, out_height, out_width}, input_arg->options());

  TensorDescriptor input_desc{*input_arg};
  TensorDescriptor output_desc{output_t};
  SpatialTransformerDescriptor sampler_desc;

  auto handle = getCudnnHandle();
  const auto data_type = getCudnnDataType(*input_arg);
  setSamplerDescriptor(sampler_desc, data_type, output_t);

  // cuDNN blending factors: output = alpha * result + beta * prior output.
  Constant alpha(data_type, 1);
  Constant beta(data_type, 0);
  AT_CUDNN_CHECK(cudnnSpatialTfSamplerForward(
      handle,
      sampler_desc.desc(),
      &alpha,
      input_desc.desc(),
      input_arg->const_data_ptr(),
      grid_arg->const_data_ptr(),
      &beta,
      output_desc.desc(),
      output_t.data_ptr()));

  return output_t;
}
124 
// NB: cuDNN does not support an output mask, so both gradients are always
// computed.
cudnn_grid_sampler_backward(const Tensor & input_t,const Tensor & grid_t,const Tensor & grad_output_t)127 std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
128     const Tensor& input_t,
129     const Tensor& grid_t,
130     const Tensor& grad_output_t) {
131   // See NOTE [ grid_sampler Native Functions ].
132   // Add checks here in case this is called instead of grid_sampler.
133   check_grid_sampler_common(input_t, grid_t);
134   TORCH_CHECK(
135       cond_cudnn_grid_sampler(input_t, grid_t),
136       "Invalid arguments to cudnn_grid_sampler_backward");
137 
138   auto input_contig = contiguousIfZeroInStrides(input_t);
139   auto grid_contig = grid_t.contiguous();
140   auto grad_output_contig = contiguousIfZeroInStrides(grad_output_t);
141   TensorArg input{input_contig, "input", 1}, grid{grid_contig, "grid", 2},
142       grad_output{grad_output_contig, "grad_output", 3};
143   CheckedFrom c = "cudnn_grid_sampler_backward";
144   checkAllSameGPU(c, {input, grad_output, grid});
145   checkGridSize(c, grid, input);
146   checkDim(c, input, 4);
147   checkDim(c, grad_output, 4);
148 
149   auto grad_input_t = at::empty({0}, input->options());
150   grad_input_t.resize_(input->sizes());
151   auto grad_grid_t = at::empty({0}, grid->options());
152   grad_grid_t.resize_(grid->sizes());
153 
154   TensorDescriptor idesc{*input}; // input descriptor
155   TensorDescriptor odesc{*grad_output}; // grad_output descriptor
156   TensorDescriptor gdesc{grad_input_t}; // grad_input descriptor
157   SpatialTransformerDescriptor desc; // sampler descriptor
158 
159   auto handle = getCudnnHandle();
160   auto dataType = getCudnnDataType(*input);
161   setSamplerDescriptor(desc, dataType, *grad_output);
162 
163   Constant one(dataType, 1);
164   Constant zero(dataType, 0);
165   AT_CUDNN_CHECK(cudnnSpatialTfSamplerBackward(
166       handle,
167       desc.desc(),
168       &one,
169       idesc.desc(),
170       input->const_data_ptr(),
171       &zero,
172       gdesc.desc(),
173       grad_input_t.data_ptr(),
174       &one,
175       odesc.desc(),
176       grad_output->const_data_ptr(),
177       // intriguingly, the outputs don't need descriptors
178       grid->const_data_ptr(),
179       &zero,
180       grad_grid_t.data_ptr()));
181 
182   return std::tuple<Tensor, Tensor>{grad_input_t, grad_grid_t};
183 }
184 
185 } // namespace native
186 } // namespace at
187 
188 #endif
189