1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Config.h>
3 #include <ATen/core/Tensor.h>
4 #include <ATen/cuda/CUDAConfig.h>
5 #include <ATen/native/GridSamplerUtils.h>
6
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/cudnn_grid_sampler_backward_native.h>
12 #include <ATen/ops/cudnn_grid_sampler_native.h>
13 #include <ATen/ops/empty.h>
14 #endif
15
16 #if !AT_CUDNN_ENABLED()
17
18 namespace at {
19 namespace native {
20
21 // See Note [ATen preprocessor philosophy]
22
// Stub emitted when ATen is built without cuDNN: always throws.
// See Note [ATen preprocessor philosophy].
Tensor cudnn_grid_sampler_forward(
    const Tensor& /*input_t*/,
    const Tensor& /*grid_t*/) {
  AT_ERROR("cudnn_grid_sampler_forward: ATen not compiled with cuDNN support");
}
26
cudnn_grid_sampler_backward(const Tensor & input_t,const Tensor & grid_t,const Tensor & grad_output_t)27 std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
28 const Tensor& input_t,
29 const Tensor& grid_t,
30 const Tensor& grad_output_t) {
31 AT_ERROR("cudnn_grid_sampler_backward: ATen not compiled with cuDNN support");
32 }
33
34 } // namespace native
35 } // namespace at
36
37 #else // AT_CUDNN_ENABLED
38
39 #include <ATen/cuda/Exceptions.h>
40 #include <ATen/cudnn/Descriptors.h>
41 #include <ATen/cudnn/Types.h>
42 #include <ATen/cudnn/Utils.h>
43
44 #include <ATen/TensorUtils.h>
45 #include <c10/util/irange.h>
46
47 // TODO: descriptor checking
48
49 namespace at {
50 namespace native {
51
52 namespace {
53
setSamplerDescriptor(SpatialTransformerDescriptor & desc,cudnnDataType_t dataType,const at::Tensor & tensor)54 void setSamplerDescriptor(
55 SpatialTransformerDescriptor& desc,
56 cudnnDataType_t dataType,
57 const at::Tensor& tensor) {
58 int inputSize[4] = {0};
59 for (const auto i : c10::irange(tensor.dim())) {
60 inputSize[i] = (int)tensor.size(i);
61 }
62 desc.set(dataType, 4, inputSize);
63 }
64
checkGridSize(CheckedFrom c,TensorArg grid,TensorArg input)65 void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input) {
66 // assert size of grid is n*h*w*2
67 // FYI: grid is between [-1, 1], where -1 left most pixel,
68 // 1 represents right most pixel (and hence 0 is the center pixel)
69 // if grid has values >1 or <-1, those values are ignored
70 checkContiguous(c, grid);
71 checkDim(c, grid, 4);
72 // TODO: Maybe more user friendly to report where the expected size
73 // came from
74 checkSize(c, grid, 0, input->size(0));
75 checkSize(c, grid, 3, 2);
76 }
77
78 } // namespace
79
// Samples `input_t` at the coordinates given by `grid_t` via
// cudnnSpatialTfSamplerForward, returning a tensor of shape
// [N, C, grid_H, grid_W].
Tensor cudnn_grid_sampler_forward(const Tensor& input_t, const Tensor& grid_t) {
  // See NOTE [ grid_sampler Native Functions ].
  // Re-validate here in case this entry point is invoked directly
  // instead of through grid_sampler.
  check_grid_sampler_common(input_t, grid_t);
  TORCH_CHECK(
      cond_cudnn_grid_sampler(input_t, grid_t),
      "Invalid arguments to cudnn_grid_sampler_forward");

  // Normalize zero strides for cuDNN; the grid must be fully contiguous.
  auto input_contig = contiguousIfZeroInStrides(input_t);
  auto grid_contig = grid_t.contiguous();
  TensorArg input{input_contig, "input", 1}, grid{grid_contig, "grid", 2};
  CheckedFrom c = "cudnn_grid_sampler_forward";
  checkAllSameGPU(c, {input, grid});
  checkAllSameType(c, {input, grid});
  checkGridSize(c, grid, input);
  checkDim(c, input, 4);

  // Output keeps input's N and C; its spatial extent comes from the grid.
  auto output = at::empty(
      {input->size(0), input->size(1), grid->size(1), grid->size(2)},
      input->options());

  TensorDescriptor input_desc{*input};
  TensorDescriptor output_desc{output};
  SpatialTransformerDescriptor sampler_desc;

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*input);
  // The sampler descriptor is sized from the OUTPUT shape.
  setSamplerDescriptor(sampler_desc, dataType, output);

  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  AT_CUDNN_CHECK(cudnnSpatialTfSamplerForward(
      handle,
      sampler_desc.desc(),
      &one,
      input_desc.desc(),
      input->const_data_ptr(),
      grid->const_data_ptr(),
      &zero,
      output_desc.desc(),
      output.data_ptr()));

  return output;
}
124
125 // NB: CuDNN does not support output mask; you always get both
126 // gradients.
cudnn_grid_sampler_backward(const Tensor & input_t,const Tensor & grid_t,const Tensor & grad_output_t)127 std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
128 const Tensor& input_t,
129 const Tensor& grid_t,
130 const Tensor& grad_output_t) {
131 // See NOTE [ grid_sampler Native Functions ].
132 // Add checks here in case this is called instead of grid_sampler.
133 check_grid_sampler_common(input_t, grid_t);
134 TORCH_CHECK(
135 cond_cudnn_grid_sampler(input_t, grid_t),
136 "Invalid arguments to cudnn_grid_sampler_backward");
137
138 auto input_contig = contiguousIfZeroInStrides(input_t);
139 auto grid_contig = grid_t.contiguous();
140 auto grad_output_contig = contiguousIfZeroInStrides(grad_output_t);
141 TensorArg input{input_contig, "input", 1}, grid{grid_contig, "grid", 2},
142 grad_output{grad_output_contig, "grad_output", 3};
143 CheckedFrom c = "cudnn_grid_sampler_backward";
144 checkAllSameGPU(c, {input, grad_output, grid});
145 checkGridSize(c, grid, input);
146 checkDim(c, input, 4);
147 checkDim(c, grad_output, 4);
148
149 auto grad_input_t = at::empty({0}, input->options());
150 grad_input_t.resize_(input->sizes());
151 auto grad_grid_t = at::empty({0}, grid->options());
152 grad_grid_t.resize_(grid->sizes());
153
154 TensorDescriptor idesc{*input}; // input descriptor
155 TensorDescriptor odesc{*grad_output}; // grad_output descriptor
156 TensorDescriptor gdesc{grad_input_t}; // grad_input descriptor
157 SpatialTransformerDescriptor desc; // sampler descriptor
158
159 auto handle = getCudnnHandle();
160 auto dataType = getCudnnDataType(*input);
161 setSamplerDescriptor(desc, dataType, *grad_output);
162
163 Constant one(dataType, 1);
164 Constant zero(dataType, 0);
165 AT_CUDNN_CHECK(cudnnSpatialTfSamplerBackward(
166 handle,
167 desc.desc(),
168 &one,
169 idesc.desc(),
170 input->const_data_ptr(),
171 &zero,
172 gdesc.desc(),
173 grad_input_t.data_ptr(),
174 &one,
175 odesc.desc(),
176 grad_output->const_data_ptr(),
177 // intriguingly, the outputs don't need descriptors
178 grid->const_data_ptr(),
179 &zero,
180 grad_grid_t.data_ptr()));
181
182 return std::tuple<Tensor, Tensor>{grad_input_t, grad_grid_t};
183 }
184
185 } // namespace native
186 } // namespace at
187
188 #endif
189