#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/native/ConvUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/miopen_convolution_add_relu_native.h>
#include <ATen/ops/miopen_convolution_native.h>
#include <ATen/ops/miopen_convolution_relu_native.h>
#include <ATen/ops/miopen_convolution_transpose_native.h>
#include <ATen/ops/miopen_depthwise_convolution_native.h>
#include <ATen/ops/squeeze.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/zeros.h>
#endif

// TODO: Remove the condition on AT_ROCM_ENABLED entirely,
// don't build this file as part of CPU build.
#include <ATen/cuda/CUDAConfig.h>

#if !AT_ROCM_ENABLED()

namespace at { namespace native {

// See Note [ATen preprocessor philosophy]

at::Tensor miopen_convolution(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt /* optional */,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  AT_ERROR("miopen_convolution: ATen not compiled with MIOpen support");
}

at::Tensor miopen_convolution_backward_input(
    IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("miopen_convolution_backward_input: ATen not compiled with MIOpen support");
}

at::Tensor miopen_convolution_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("miopen_convolution_backward_weight: ATen not compiled with MIOpen support");
}

at::Tensor miopen_convolution_backward_bias(
    const at::Tensor& grad_output) {
  AT_ERROR("miopen_convolution_backward_bias: ATen not compiled with MIOpen support");
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
  AT_ERROR("miopen_convolution_backward: ATen not compiled with MIOpen support");
}

at::Tensor miopen_convolution_transpose(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt /* optional */,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  AT_ERROR("miopen_convolution_transpose: ATen not compiled with MIOpen support");
}

at::Tensor miopen_convolution_transpose_backward_input(
    const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  AT_ERROR("miopen_convolution_transpose_backward_input: ATen not compiled with MIOpen support");
}

at::Tensor miopen_convolution_transpose_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("miopen_convolution_transpose_backward_weight: ATen not compiled with MIOpen support");
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_transpose_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
  AT_ERROR("miopen_convolution_transpose_backward: ATen not compiled with MIOpen support");
}

at::Tensor miopen_depthwise_convolution(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt /* optional */,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  AT_ERROR("miopen_depthwise_convolution: ATen not compiled with MIOpen support");
}

at::Tensor miopen_depthwise_convolution_backward_input(
    IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("miopen_depthwise_convolution_backward_input: ATen not compiled with MIOpen support");
}

at::Tensor miopen_depthwise_convolution_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("miopen_depthwise_convolution_backward_weight: ATen not compiled with MIOpen support");
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_depthwise_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
  AT_ERROR("miopen_depthwise_convolution_backward: ATen not compiled with MIOpen support");
}

at::Tensor miopen_convolution_add_relu(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& z,
    const std::optional<Scalar>& alpha, const std::optional<Tensor>& bias, IntArrayRef stride,
    IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  AT_ERROR("miopen_convolution_add_relu: ATen not compiled with MIOpen support");
}

at::Tensor miopen_convolution_relu(
    const at::Tensor& input, const at::Tensor& weight, const std::optional<Tensor>& bias,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  AT_ERROR("miopen_convolution_relu: ATen not compiled with MIOpen support");
}

}}

#else // AT_ROCM_ENABLED

#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/miopen/Descriptors.h>
#include <ATen/miopen/Types.h>
#include <ATen/miopen/Utils.h>
#include <ATen/hip/EmptyTensor.h>

#include <ATen/TensorUtils.h>
#include <ATen/native/ConvUtils.h>
#include <c10/util/irange.h>

#include <c10/hip/HIPCachingAllocator.h>

#include <functional>
#include <iterator>
#include <sstream>
#include <algorithm>
#include <memory>
#include <mutex>
#include <stdint.h>
#include <unordered_map>

#define AT_MIOPEN_MAX_SOLUTIONS 10

namespace at { namespace native {

Tensor narrowGroup(const Tensor& t, int dim, int group_idx, int64_t groups) {
  auto group_size = t.size(dim) / groups;
  return t.narrow(dim, group_idx * group_size, group_size);
}
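
// For example, with dim == 0 and groups == 2, a weight tensor whose leading
// dimension is 64 is narrowed to the rows [0, 32) for group_idx == 0 and to
// [32, 64) for group_idx == 1.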

// This POD struct is used to let us easily compute hashes of the
// parameters
struct ConvolutionParams
{
  miopenHandle_t handle;
  miopenDataType_t dataType;
  int input_size[2 + max_dim];
  int input_stride[2 + max_dim];
  int weight_size[2 + max_dim];
  int padding[max_dim];
  int stride[max_dim];
  int dilation[max_dim];
  int64_t groups;
  bool deterministic;
  int device_id; // This is needed to distinguish between miopen handles of multiple gpus.
  // NB: transposed purposely omitted: transposed just swaps
  // forward and backward, so you can reuse the benchmark entry.
};
// ConvolutionParams must be a POD because we read out its memory
// contents as char* when hashing
static_assert(std::is_standard_layout<ConvolutionParams>::value, "ConvolutionParams not POD");

void setConvolutionParams(
    ConvolutionParams* params, miopenHandle_t handle,
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool deterministic) {

  miopenDataType_t dataType = getMiopenDataType(input);
  memset(params, 0, sizeof(ConvolutionParams));
  params->dataType = dataType;
  params->handle = handle;
  // ASSERT(weight.dim() == input.dim())
  for (int i = 0; i != input.dim(); ++i) {
    params->input_size[i] = (int) input.size(i);
    params->input_stride[i] = (int) input.stride(i);
    params->weight_size[i] = (int) weight.size(i);
  }
  // ASSERT(padding.size() == stride.size())
  // ASSERT(padding.size() == dilation.size())
  for (size_t i = 0; i != padding.size(); ++i) {
    params->padding[i] = padding[i];
    params->stride[i] = stride[i];
    params->dilation[i] = dilation[i];
  }
  params->groups = groups;
  params->deterministic = deterministic;
  int device_id;
  HIP_CHECK(hipGetDevice(&device_id));
  params->device_id = device_id;
}
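
// For instance, for a contiguous 4-d NCHW input the loop above records
// input_size = {N, C, H, W} and input_stride = {C*H*W, H*W, W, 1}, so two
// convolutions only share a cache entry when both their shapes and their
// memory layouts match.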

// Convenience struct for passing around descriptors and data
// pointers
struct ConvolutionArgs {
  miopenHandle_t handle;
  ConvolutionParams params;
  TensorDescriptor idesc, odesc;
  FilterDescriptor wdesc;
  const Tensor& input, output, weight;
  ConvolutionDescriptor cdesc;

  ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight) : input(input), output(output), weight(weight) {
  }
};

// ---------------------------------------------------------------------
//
// Benchmarking
//
// ---------------------------------------------------------------------

// Hashing machinery for ConvolutionParams
struct ParamsHash {
  std::size_t operator()(const ConvolutionParams& params) const {
    auto ptr = reinterpret_cast<const uint8_t*>(&params);
    uint32_t value = 0x811C9DC5;
    for (const auto i : c10::irange((int)sizeof(ConvolutionParams))) {
      value ^= ptr[i];
      value *= 0x01000193;
    }
    return (size_t)value;
  }
};

struct ParamsEqual {
  bool operator()(const ConvolutionParams& a, const ConvolutionParams& b) const {
    auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
    auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
    return memcmp(ptr1, ptr2, sizeof(ConvolutionParams)) == 0;
  }
};
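
// ParamsHash is the 32-bit FNV-1a hash (offset basis 0x811C9DC5, prime
// 0x01000193) computed over the raw bytes of ConvolutionParams, and
// ParamsEqual is a plain memcmp.  Both are only well-defined because
// setConvolutionParams memsets the struct to zero first, which also zeroes
// any compiler-inserted padding bytes.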

template <typename T>
struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<ConvolutionParams, T, ParamsHash, ParamsEqual> map;

  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> guard(mutex);
    auto it = map.find(params);
    if (it == map.end()) {
      return false;
    }
    *results = it->second;
    return true;
  }

  void insert(const ConvolutionParams& params, const T& results) {
    std::lock_guard<std::mutex> guard(mutex);
    map[params] = results;
  }
};
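
// Each BenchmarkCache instance below is a process-wide map from
// ConvolutionParams to either a chosen algorithm or its workspace size; the
// mutex makes concurrent find()/insert() calls from different threads safe.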

BenchmarkCache<miopenConvFwdAlgorithm_t> fwd_algos;
BenchmarkCache<miopenConvBwdDataAlgorithm_t> bwd_data_algos;
BenchmarkCache<miopenConvBwdWeightsAlgorithm_t> bwd_filter_algos;

BenchmarkCache<size_t> fwd_wssizes;
BenchmarkCache<size_t> bwd_data_wssizes;
BenchmarkCache<size_t> bwd_filter_wssizes;

struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    data = c10::hip::HIPCachingAllocator::raw_alloc(size);
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      c10::hip::HIPCachingAllocator::raw_delete(data);
    }
  }

  size_t size;
  void* data;
};
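
// Workspace is a small RAII wrapper around the HIP caching allocator: the
// scratch buffer MIOpen asks for is allocated on construction and returned to
// the allocator when the object goes out of scope.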

template<typename algo_t>
struct algorithm_search {
};

size_t getWorkspaceSize(
    const ConvolutionArgs& args, const miopenConvFwdAlgorithm_t)
{
  size_t sz = 0;
  miopenConvolutionForwardGetWorkSpaceSize(
      args.handle,
      args.wdesc.desc(),
      args.idesc.desc(),
      args.cdesc.desc(),
      args.odesc.desc(),
      &sz);
  return sz;
}
size_t getWorkspaceSize(
    const ConvolutionArgs& args, const miopenConvBwdDataAlgorithm_t)
{
  size_t sz = 0;
  miopenConvolutionBackwardDataGetWorkSpaceSize(
      args.handle,
      args.odesc.desc(),
      args.wdesc.desc(),
      args.cdesc.desc(),
      args.idesc.desc(),
      &sz);
  return sz;
}
size_t getWorkspaceSize(
    const ConvolutionArgs& args, const miopenConvBwdWeightsAlgorithm_t)
{
  size_t sz = 0;
  miopenConvolutionBackwardWeightsGetWorkSpaceSize(
      args.handle,
      args.odesc.desc(),
      args.idesc.desc(),
      args.cdesc.desc(),
      args.wdesc.desc(),
      &sz);
  return sz;
}

template<typename perf_t>
perf_t getBestAlgorithm(perf_t *perfResults, bool deterministic, int n_algo) {
  return perfResults[0];
}
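
// The Find calls below request only a single result ("just return the
// fastest"), so returning perfResults[0] here is sufficient; the
// deterministic and n_algo arguments are currently unused.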

template<>
struct algorithm_search<miopenConvFwdAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvFwdAlgorithm_t;

  static constexpr auto DEFAULT_ALGO = miopenConvolutionFwdAlgoGEMM;
  static BenchmarkCache<algo_t>& cache() { return fwd_algos; }
  static BenchmarkCache<size_t>& wsscache() { return fwd_wssizes; }

  static perf_t findAlgorithm(const ConvolutionArgs& args) {
    int perf_count;
    perf_t perf_results;
    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
    Workspace ws(max_ws_size);
    MIOPEN_CHECK(miopenFindConvolutionForwardAlgorithm(
        args.handle,
        args.idesc.desc(), args.input.const_data_ptr(),
        args.wdesc.desc(), args.weight.const_data_ptr(),
        args.cdesc.desc(),
        args.odesc.desc(), args.output.data_ptr(),
        1, // just return the fastest
        &perf_count,
        &perf_results,
        ws.data,
        ws.size,
        false));
    return perf_results;
  }

  static miopenConvSolution_t getSolution(const ConvolutionArgs& args, bool force_default) {
    size_t max_solution_count;
    size_t solution_count;
    miopenConvSolution_t solutions[AT_MIOPEN_MAX_SOLUTIONS];
    MIOPEN_CHECK(miopenConvolutionForwardGetSolutionCount(
        args.handle,
        args.wdesc.desc(),
        args.idesc.desc(),
        args.cdesc.desc(),
        args.odesc.desc(),
        &max_solution_count));
    if (max_solution_count > AT_MIOPEN_MAX_SOLUTIONS) {
      AT_ERROR("miopenConvFwdAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS");
    }
    MIOPEN_CHECK(miopenConvolutionForwardGetSolution(
        args.handle,
        args.wdesc.desc(),
        args.idesc.desc(),
        args.cdesc.desc(),
        args.odesc.desc(),
        max_solution_count,
        &solution_count,
        solutions));

    if (force_default) {
      // find default alg
      for (size_t i=0; i<solution_count; ++i) {
        if (solutions[i].algorithm == (miopenConvAlgorithm_t)DEFAULT_ALGO) {
          return solutions[i];
        }
      }
      // default algo was not found, select first algo without workspace requirement
      for (size_t i=0; i<solution_count; ++i) {
        if (solutions[i].workspace_size == 0) {
          return solutions[i];
        }
      }
      // now what? fall through and hope for the best
    }

    return solutions[0];
  }
};

template<>
struct algorithm_search<miopenConvBwdDataAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvBwdDataAlgorithm_t;

  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdDataAlgoGEMM;
  static BenchmarkCache<algo_t>& cache() { return bwd_data_algos; }
  static BenchmarkCache<size_t>& wsscache() { return bwd_data_wssizes; }

  static perf_t findAlgorithm(const ConvolutionArgs& args) {
    int perf_count;
    perf_t perf_results;
    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
    Workspace ws(max_ws_size);
    MIOPEN_CHECK(miopenFindConvolutionBackwardDataAlgorithm(
        args.handle,
        args.odesc.desc(), args.output.const_data_ptr(),
        args.wdesc.desc(), args.weight.const_data_ptr(),
        args.cdesc.desc(),
        args.idesc.desc(), args.input.data_ptr(),
        1, // just return the fastest
        &perf_count,
        &perf_results,
        ws.data,
        ws.size,
        false));
    return perf_results;
  }

  static miopenConvSolution_t getSolution(const ConvolutionArgs& args, bool force_default) {
    size_t max_solution_count;
    size_t solution_count;
    miopenConvSolution_t solutions[AT_MIOPEN_MAX_SOLUTIONS];
    MIOPEN_CHECK(miopenConvolutionBackwardDataGetSolutionCount(
        args.handle,
        args.odesc.desc(),
        args.wdesc.desc(),
        args.cdesc.desc(),
        args.idesc.desc(),
        &max_solution_count));
    if (max_solution_count > AT_MIOPEN_MAX_SOLUTIONS) {
      AT_ERROR("miopenConvBwdDataAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS");
    }
    MIOPEN_CHECK(miopenConvolutionBackwardDataGetSolution(
        args.handle,
        args.odesc.desc(),
        args.wdesc.desc(),
        args.cdesc.desc(),
        args.idesc.desc(),
        max_solution_count,
        &solution_count,
        solutions));

    if (force_default) {
      // find default alg
      for (size_t i=0; i<solution_count; ++i) {
        if (solutions[i].algorithm == (miopenConvAlgorithm_t)DEFAULT_ALGO) {
          return solutions[i];
        }
      }
      // default algo was not found, select first algo without workspace requirement
      for (size_t i=0; i<solution_count; ++i) {
        if (solutions[i].workspace_size == 0) {
          return solutions[i];
        }
      }
      // now what? fall through and hope for the best
    }

    return solutions[0];
  }
};

template<>
struct algorithm_search<miopenConvBwdWeightsAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvBwdWeightsAlgorithm_t;

  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdWeightsAlgoGEMM;
  static BenchmarkCache<algo_t>& cache() { return bwd_filter_algos; }
  static BenchmarkCache<size_t>& wsscache() { return bwd_filter_wssizes; }

  static perf_t findAlgorithm(const ConvolutionArgs& args) {
    int perf_count;
    perf_t perf_results;
    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
    Workspace ws(max_ws_size);
    MIOPEN_CHECK(miopenFindConvolutionBackwardWeightsAlgorithm(
        args.handle,
        args.odesc.desc(), args.output.const_data_ptr(),
        args.idesc.desc(), args.input.const_data_ptr(),
        args.cdesc.desc(),
        args.wdesc.desc(), args.weight.data_ptr(),
        1, // just return the fastest
        &perf_count,
        &perf_results,
        ws.data,
        ws.size,
        false));
    return perf_results;
  }

  static miopenConvSolution_t getSolution(const ConvolutionArgs& args, bool force_default) {
    size_t max_solution_count;
    size_t solution_count;
    miopenConvSolution_t solutions[AT_MIOPEN_MAX_SOLUTIONS];
    MIOPEN_CHECK(miopenConvolutionBackwardWeightsGetSolutionCount(
        args.handle,
        args.odesc.desc(),
        args.idesc.desc(),
        args.cdesc.desc(),
        args.wdesc.desc(),
        &max_solution_count));
    if (max_solution_count > AT_MIOPEN_MAX_SOLUTIONS) {
      AT_ERROR("miopenConvBwdWeightsAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS");
    }
    MIOPEN_CHECK(miopenConvolutionBackwardWeightsGetSolution(
        args.handle,
        args.odesc.desc(),
        args.idesc.desc(),
        args.cdesc.desc(),
        args.wdesc.desc(),
        max_solution_count,
        &solution_count,
        solutions));

    if (force_default) {
      // find default alg
      for (size_t i=0; i<solution_count; ++i) {
        if (solutions[i].algorithm == (miopenConvAlgorithm_t)DEFAULT_ALGO) {
          return solutions[i];
        }
      }
      // default algo was not found, select first algo without workspace requirement
      for (size_t i=0; i<solution_count; ++i) {
        if (solutions[i].workspace_size == 0) {
          return solutions[i];
        }
      }
      // now what? fall through and hope for the best
    }

    return solutions[0];
  }
};

template<typename algo_t>
void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
  using search = algorithm_search<algo_t>;
  auto& cache = search::cache();
  auto& wsscache = search::wsscache();

  if (cache.find(args.params, algo)) {
    return;
  }

  if (args.params.deterministic && !benchmark) {
    *algo = search::DEFAULT_ALGO;
  }

  if (cache.find(args.params, algo)) {
    // re-check cache since another thread may have benchmarked the algorithm
    return;
  }

  auto perfResults = search::findAlgorithm(args);
  *algo = reinterpret_cast<algo_t&>(perfResults);

  cache.insert(args.params, *algo);
  wsscache.insert(args.params, perfResults.memory);

  if (at::native::_cudnn_get_conv_benchmark_empty_cache()) {
    c10::hip::HIPCachingAllocator::emptyCache();
  }

}
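
// findAlgorithm first consults the cache, then runs MIOpen's Find API to time
// candidate algorithms, recording both the winning algorithm and its
// workspace size; if _cudnn_get_conv_benchmark_empty_cache() returns true,
// the caching allocator is emptied again once the search is done.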

template<typename algo_t>
Workspace chooseAlgorithm(
    const ConvolutionArgs& args,
    bool benchmark,
    algo_t* algo)
{
  findAlgorithm(args, benchmark, algo);

  using search = algorithm_search<algo_t>;
  size_t workspace_size;
  search::wsscache().find(args.params, &workspace_size);
  try {
    return Workspace(workspace_size);
  } catch (const std::exception& e) {
    hipGetLastError(); // clear OOM error

    // switch to default algorithm and record it in the cache to prevent
    // further OOM errors
    *algo = search::DEFAULT_ALGO;
    workspace_size = getWorkspaceSize(args, *algo);
    search::cache().insert(args.params, *algo);
    search::wsscache().insert(args.params, workspace_size);
    return Workspace(workspace_size);
  }
}

template<typename algo_t>
Workspace chooseSolution(const ConvolutionArgs& args, uint64_t* solution_id)
{
  using search = algorithm_search<algo_t>;
  miopenConvSolution_t solution = search::getSolution(args, false);
  try {
    *solution_id = solution.solution_id;
    return Workspace(solution.workspace_size);
  } catch (const std::exception& e) {
    hipGetLastError(); // clear OOM error

    // switch to default algorithm
    solution = search::getSolution(args, true);
    *solution_id = solution.solution_id;
    return Workspace(solution.workspace_size);
  }
}
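
// chooseAlgorithm and chooseSolution mirror each other: both try to allocate
// the workspace their preferred choice asks for and, if that allocation
// throws (e.g. out of GPU memory), fall back to the default algorithm
// (chooseAlgorithm) or to the default / a zero-workspace solution
// (chooseSolution).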

// ---------------------------------------------------------------------
//
// Bias addition
//
// ---------------------------------------------------------------------

// In-place!
void miopen_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const TensorArg& bias)
{
  checkAllSameType(c, {output, bias});
  checkAllSameGPU(c, {output, bias});
  checkSize(c, bias, { output->size(output_channels_dim) });

  TensorDescriptor bdesc, odesc;

  auto memory_format = output->suggest_memory_format();

  std::vector<int64_t> shape( output->dim(), 1);
  shape[output_channels_dim] = -1;
  at::Tensor bias_contig = bias->reshape(shape).contiguous(memory_format);
  // Make sure that NC11 strides follow formula
  bias_contig.resize_(bias_contig.sizes(), memory_format);

  // TODO: Workaround since MIOpen does not support NHWC bias
  // See #64426
  output->add_( bias_contig );

  /* MIOpen does not support NHWC bias; Activate once support is added.
  bdesc.set( bias_contig );
  odesc.set(*output);

  auto handle = getMiopenHandle();
  auto dataType = getMiopenDataType(*bias);
  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  MIOPEN_CHECK(miopenConvolutionForwardBias(handle, &one, bdesc.desc(), bias->const_data_ptr(),
                                            &zero, odesc.desc(), output->data_ptr()));
  */
}

// see NOTE [ Convolution design ] in src/ATen/native/cudnn/Conv.cpp


// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
//
// ---------------------------------------------------------------------

// The raw API directly invokes MIOpen.
//
// There are a few reasons this should never be directly exposed
// via ATen:
//
//    - It takes output as a parameter (this should be computed!)
//    - It doesn't do input checking
//    - It doesn't resize output (it is assumed to be correctly sized)
//
void raw_miopen_convolution_forward_out(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {

  auto dataType = getMiopenDataType(input);
  miopenConvolutionMode_t c_mode = miopenConvolution;

  ConvolutionArgs args{ input, output, weight };
  args.handle = getMiopenHandle();
  setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(input);
  args.wdesc.set(weight, input.suggest_memory_format(), 0);
  args.odesc.set(output);
  args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, deterministic);

  if (benchmark) {
    miopenConvFwdAlgorithm_t fwdAlg;
    Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlg);

    Constant one(dataType, 1);
    Constant zero(dataType, 0);

    MIOPEN_CHECK(miopenConvolutionForward(
        args.handle,
        &one, args.idesc.desc(), input.const_data_ptr(),
        args.wdesc.desc(), weight.const_data_ptr(),
        args.cdesc.desc(), fwdAlg, &zero,
        args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size));
  }
  else {
    uint64_t solution_id;
    Workspace workspace = chooseSolution<miopenConvFwdAlgorithm_t>(args, &solution_id);

    MIOPEN_CHECK(miopenConvolutionForwardImmediate(
        args.handle,
        args.wdesc.desc(), weight.const_data_ptr(),
        args.idesc.desc(), input.const_data_ptr(),
        args.cdesc.desc(),
        args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size, solution_id));
  }
}
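
// Note on the two branches above: with benchmark == true the algorithm comes
// from the cached Find search (chooseAlgorithm), while benchmark == false
// uses MIOpen's immediate mode, where chooseSolution picks a solution id that
// is passed straight to the *Immediate entry point.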

Tensor miopen_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *weight)) {
    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  }

  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(input->sizes(), weight->sizes(),
                       padding, stride, dilation),
      input->options().memory_format(memory_format));

  if (output_t.numel() == 0) {
    return output_t;
  }

  // Avoid ambiguity of "output" when this is being used as backwards
  TensorArg output{ output_t, "result", 0 };
  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  // See #4500
  Tensor weight_contig = weight->contiguous(memory_format);
  // Make sure that NC11 strides follow formula
  weight_contig.resize_(weight_contig.sizes(), memory_format);
  Tensor input_contig = input->contiguous(memory_format);
  input_contig.resize_(input_contig.sizes(), memory_format);

  raw_miopen_convolution_forward_out(
      *output, input_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic);

  return *output;
}

Tensor miopen_convolution(
    const Tensor& input_t, const Tensor& weight_t, const std::optional<Tensor>& bias_t_opt,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt);
  const Tensor& bias_t = *bias_t_maybe_owned;

  TensorArg input  { input_t,  "input",  1 },
            weight { weight_t, "weight", 2 },
            bias   { bias_t,   "bias",   3 };
  CheckedFrom c = "miopen_convolution";
  auto output_t = miopen_convolution_forward(
    c, input, weight, padding, stride, dilation, groups, benchmark, deterministic);
  if (bias->defined()) {
    miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias);
  }
  return output_t;
}
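
// Illustrative call from ATen (hypothetical shapes): for a float NCHW input
// of shape [8, 3, 32, 32] and a weight of shape [16, 3, 3, 3], something like
//
//   auto out = at::miopen_convolution(
//       input, weight, /*bias=*/std::nullopt,
//       /*padding=*/{1, 1}, /*stride=*/{1, 1}, /*dilation=*/{1, 1},
//       /*groups=*/1, /*benchmark=*/true, /*deterministic=*/false);
//
// dispatches here and produces an output of shape [8, 16, 32, 32].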

// Depthwise convolutions
void raw_miopen_depthwise_convolution_forward_out(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {

  auto dataType = getMiopenDataType(input);
  miopenConvolutionMode_t c_mode = miopenDepthwise;

  ConvolutionArgs args{ input, output, weight };
  args.handle = getMiopenHandle();
  setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(input);
  args.wdesc.set(weight, input.suggest_memory_format(), 0);
  args.odesc.set(output);
  args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, deterministic);

  if (benchmark) {
    miopenConvFwdAlgorithm_t fwdAlg;
    Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlg);

    Constant one(dataType, 1);
    Constant zero(dataType, 0);

    MIOPEN_CHECK(miopenConvolutionForward(
        args.handle,
        &one, args.idesc.desc(), input.const_data_ptr(),
        args.wdesc.desc(), weight.const_data_ptr(),
        args.cdesc.desc(), fwdAlg, &zero,
        args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size));
  }
  else {
    uint64_t solution_id;
    Workspace workspace = chooseSolution<miopenConvFwdAlgorithm_t>(args, &solution_id);

    MIOPEN_CHECK(miopenConvolutionForwardImmediate(
        args.handle,
        args.wdesc.desc(), weight.const_data_ptr(),
        args.idesc.desc(), input.const_data_ptr(),
        args.cdesc.desc(),
        args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size, solution_id));
  }
}

Tensor miopen_depthwise_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *weight)) {
    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  }

  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(input->sizes(), weight->sizes(),
                       padding, stride, dilation),
      input->options().memory_format(memory_format));

  TensorArg output{ output_t, "result", 0 };
  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  // See #4500
  Tensor weight_contig = weight->contiguous(memory_format);
  // Make sure that NC11 strides follow formula
  weight_contig.resize_(weight_contig.sizes(), memory_format);
  Tensor input_contig = input->contiguous(memory_format);
  input_contig.resize_(input_contig.sizes(), memory_format);

  raw_miopen_depthwise_convolution_forward_out(
      *output, input_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic);

  return *output;
}

Tensor miopen_depthwise_convolution(
    const Tensor& input_t, const Tensor& weight_t, const std::optional<Tensor>& bias_t_opt,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt);
  const Tensor& bias_t = *bias_t_maybe_owned;

  TensorArg input  { input_t,  "input",  1 },
            weight { weight_t, "weight", 2 },
            bias   { bias_t,   "bias",   3 };
  CheckedFrom c = "miopen_depthwise_convolution";
  auto output_t = miopen_depthwise_convolution_forward(
    c, input, weight, padding, stride, dilation, groups, benchmark, deterministic);
  if (bias->defined()) {
    miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias);
  }
  return output_t;
}

// ---------------------------------------------------------------------
//
// Convolution backward (bias)
//
// ---------------------------------------------------------------------

Tensor miopen_convolution_backward_bias(
    const Tensor& grad_output_t)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 };

  // TODO: Workaround since MIOpen does not support NHWC bias
  // See #64426
  std::vector<int64_t> discard_dims;
  for( int i = 0; i < grad_output_t.dim(); i++ ) {
    if(i != output_channels_dim ) {
      discard_dims.push_back(i);
    }
  }

  Tensor outputBias = at::squeeze( at::sum(grad_output_t, discard_dims, true) );
  if( outputBias.dim() == 0 ) {
    // always return a tensor of shape [_]
    return outputBias.unsqueeze(0);
  }
  else {
    return outputBias;
  }

  /* MIOpen does not support NHWC bias. Activate once support is added.
  auto grad_bias_t = at::empty( { grad_output->size(output_channels_dim) }, grad_output->options());

  TensorArg grad_bias{ grad_bias_t, "result", 0 };

  TensorDescriptor bdesc{grad_bias->expand({1, grad_bias->size(0)}),
                         static_cast<size_t>(grad_output->dim())};
  TensorDescriptor odesc{*grad_output};

  auto handle = getMiopenHandle();
  auto dataType = getMiopenDataType(*grad_bias);
  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  MIOPEN_CHECK(miopenConvolutionBackwardBias(handle, &one, odesc.desc(), grad_output->data_ptr(),
                                             &zero, bdesc.desc(), grad_bias->data_ptr()));
  return *grad_bias;
  */
}
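
// Concretely: for a 4-d NCHW grad_output of shape [N, C, H, W] the loop above
// collects discard_dims = {0, 2, 3} (everything except output_channels_dim),
// so the sum/squeeze pair yields a bias gradient of shape [C]; the
// unsqueeze(0) branch only fires in the degenerate C == 1 case, where squeeze
// would otherwise produce a 0-dimensional tensor.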

// ---------------------------------------------------------------------
//
// Convolution backward (weight)
//
// ---------------------------------------------------------------------

void raw_miopen_convolution_backward_weight_out(
    const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {

  auto dataType = getMiopenDataType(input);
  miopenConvolutionMode_t c_mode = miopenConvolution;

  ConvolutionArgs args{ input, grad_output, grad_weight };
  args.handle = getMiopenHandle();
  setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(input);
  args.wdesc.set(grad_weight, input.suggest_memory_format(), 0);
  args.odesc.set(grad_output);
  args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, deterministic);

  if (benchmark) {
    miopenConvBwdWeightsAlgorithm_t bwdFilterAlg;
    Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg);

    Constant one(dataType, 1);
    Constant zero(dataType, 0);

    MIOPEN_CHECK(miopenConvolutionBackwardWeights(
        args.handle,
        &one, args.odesc.desc(), grad_output.const_data_ptr(),
        args.idesc.desc(), input.const_data_ptr(),
        args.cdesc.desc(), bwdFilterAlg, &zero,
        args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size));
  }
  else {
    uint64_t solution_id;
    Workspace workspace = chooseSolution<miopenConvBwdWeightsAlgorithm_t>(args, &solution_id);

    MIOPEN_CHECK(miopenConvolutionBackwardWeightsImmediate(
        args.handle,
        args.odesc.desc(), grad_output.const_data_ptr(),
        args.idesc.desc(), input.const_data_ptr(),
        args.cdesc.desc(),
        args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size, solution_id));
  }
}

// Depthwise backward weights.
void raw_miopen_depthwise_convolution_backward_weight_out(
    const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {

  auto dataType = getMiopenDataType(input);
  miopenConvolutionMode_t c_mode = miopenDepthwise;

  ConvolutionArgs args{ input, grad_output, grad_weight };
  args.handle = getMiopenHandle();
  setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(input);
  args.wdesc.set(grad_weight, input.suggest_memory_format(), 0);
  args.odesc.set(grad_output);
  args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, deterministic);

  if (benchmark) {
    miopenConvBwdWeightsAlgorithm_t bwdFilterAlg;
    Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg);

    Constant one(dataType, 1);
    Constant zero(dataType, 0);

    MIOPEN_CHECK(miopenConvolutionBackwardWeights(
        args.handle,
        &one, args.odesc.desc(), grad_output.const_data_ptr(),
        args.idesc.desc(), input.const_data_ptr(),
        args.cdesc.desc(), bwdFilterAlg, &zero,
        args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size));
  }
  else {
    uint64_t solution_id;
    Workspace workspace = chooseSolution<miopenConvBwdWeightsAlgorithm_t>(args, &solution_id);

    MIOPEN_CHECK(miopenConvolutionBackwardWeightsImmediate(
        args.handle,
        args.odesc.desc(), grad_output.const_data_ptr(),
        args.idesc.desc(), input.const_data_ptr(),
        args.cdesc.desc(),
        args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size, solution_id));
  }
}

Tensor miopen_depthwise_convolution_backward_weight(
    CheckedFrom c,
    IntArrayRef weight_size, const TensorArg& grad_output, const TensorArg& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{

  checkAllSameType(c, {grad_output, input});
  checkAllSameGPU(c, {grad_output, input});

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *grad_output)) {
    memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  }

  Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
  // Make sure that NC11 strides follow formula
  grad_output_contig_t.resize_(grad_output_contig_t.sizes(), memory_format);
  TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 };

  Tensor input_contig_t = input->contiguous(memory_format);
  input_contig_t.resize_(input_contig_t.sizes(), memory_format);
  TensorArg input_contig{ input_contig_t, "input", 2};

  auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), memory_format);

  // For uniformity with everything else, although it seems grad_weight
  // would be unambiguous too.
  TensorArg grad_weight{ grad_weight_t, "result", 0 };
  convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups);

  raw_miopen_depthwise_convolution_backward_weight_out(
      *grad_weight, *grad_output_contig, *input_contig,
      padding, stride, dilation, groups, benchmark, deterministic);

  return grad_weight_t;
}

Tensor miopen_depthwise_convolution_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
            input{ input_t, "input", 2 };
  return miopen_depthwise_convolution_backward_weight(
      "miopen_depthwise_convolution_backward_weight",
      weight_size, grad_output, input,
      padding, stride, dilation, groups, benchmark, deterministic);
}

Tensor miopen_convolution_backward_weight(
    CheckedFrom c,
    IntArrayRef weight_size, const TensorArg& grad_output, const TensorArg& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{

  checkAllSameType(c, {grad_output, input});
  checkAllSameGPU(c, {grad_output, input});

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *grad_output)) {
    memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  }

  Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
  // Make sure that NC11 strides follow formula
  grad_output_contig_t.resize_(grad_output_contig_t.sizes(), memory_format);
  TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 };

  Tensor input_contig_t = input->contiguous(memory_format);
  input_contig_t.resize_(input_contig_t.sizes(), memory_format);
  TensorArg input_contig{ input_contig_t, "input", 2};

  auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), memory_format);

  // For uniformity with everything else, although it seems grad_weight
  // would be unambiguous too.
  TensorArg grad_weight{ grad_weight_t, "result", 0 };
  convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups);

  raw_miopen_convolution_backward_weight_out(
      *grad_weight, *grad_output_contig, *input_contig,
      padding, stride, dilation, groups, benchmark, deterministic);

  return grad_weight_t;
}

Tensor miopen_convolution_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
            input{ input_t, "input", 2 };
  return miopen_convolution_backward_weight(
      "miopen_convolution_backward_weight",
      weight_size, grad_output, input,
      padding, stride, dilation, groups, benchmark, deterministic);
}

Tensor miopen_convolution_transpose_backward_input(
    const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  TensorArg grad_output { grad_output_t, "grad_output", 1 },
            weight { weight_t, "weight", 2 };
  return miopen_convolution_forward(
      "miopen_convolution_transpose_backward_input",
      grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
}

Tensor miopen_convolution_transpose_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
            input{ input_t, "input", 2 };
  return miopen_convolution_backward_weight(
      "miopen_convolution_backward_weight",
      weight_size, input, grad_output,
      padding, stride, dilation, groups, benchmark, deterministic);
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_transpose_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {

  Tensor grad_output = grad_output_t.contiguous();

  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = miopen_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[1]) {
    grad_weight = miopen_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[2]) {
    grad_bias = miopen_convolution_backward_bias(grad_output);
  }

  return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias};
}
1214
1215 // ---------------------------------------------------------------------
1216 //
1217 // Convolution backward / Transposed convolution forward
1218 //
1219 // ---------------------------------------------------------------------
1220
void raw_miopen_convolution_backward_input_out(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {

  auto dataType = getMiopenDataType(grad_output);
  miopenConvolutionMode_t c_mode = miopenConvolution;

  ConvolutionArgs args{ grad_input, grad_output, weight };
  args.handle = getMiopenHandle();
  setConvolutionParams(&args.params, args.handle, grad_input, weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(grad_input);
  args.wdesc.set(weight, grad_output.suggest_memory_format(), 0);
  args.odesc.set(grad_output);
  args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, deterministic);

  if (benchmark) {
    miopenConvBwdDataAlgorithm_t bwdDataAlg;
    Workspace workspace = chooseAlgorithm(args, benchmark, &bwdDataAlg);

    Constant one(dataType, 1);
    Constant zero(dataType, 0);

    MIOPEN_CHECK(miopenConvolutionBackwardData(
        args.handle,
        &one, args.odesc.desc(), grad_output.const_data_ptr(),
        args.wdesc.desc(), weight.const_data_ptr(),
        args.cdesc.desc(), bwdDataAlg, &zero,
        args.idesc.desc(), grad_input.mutable_data_ptr(), workspace.data, workspace.size));
  }
  else {
    uint64_t solution_id;
    Workspace workspace = chooseSolution<miopenConvBwdDataAlgorithm_t>(args, &solution_id);

    MIOPEN_CHECK(miopenConvolutionBackwardDataImmediate(
        args.handle,
        args.odesc.desc(), grad_output.const_data_ptr(),
        args.wdesc.desc(), weight.const_data_ptr(),
        args.cdesc.desc(),
        args.idesc.desc(), grad_input.mutable_data_ptr(), workspace.data, workspace.size, solution_id));
  }
}

// see NOTE [ Backward vs transpose convolutions ] in src/Aten/native/cudnn/Conv.cpp

Tensor miopen_convolution_backward_input(
    CheckedFrom c,
    IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {grad_output, weight});
  checkAllSameGPU(c, {grad_output, weight});

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*grad_output, *weight)) {
    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  }

  Tensor grad_input_t = at::detail::empty_cuda(
      input_size, grad_output->options().memory_format(memory_format));

  // Avoid "grad_input" when this is being used as transposed convolution
  TensorArg grad_input{ grad_input_t, "result", 0 };
  convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);

  // See #4500
  Tensor weight_contig = weight->contiguous(memory_format);
  // Make sure that NC11 strides follow formula
  weight_contig.resize_(weight_contig.sizes(), memory_format);

  Tensor grad_output_contig = grad_output->contiguous(memory_format);
  grad_output_contig.resize_(grad_output_contig.sizes(), memory_format);

  raw_miopen_convolution_backward_input_out(
      *grad_input, grad_output_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic);

  return *grad_input;
}

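// The forward pass of a transposed convolution is computed as the backward-data
// pass of the corresponding regular convolution; conv_input_size recovers the
// (larger) output shape from grad_output's shape and the convolution parameters.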
Tensor miopen_convolution_transpose_forward(
    CheckedFrom c,
    const TensorArg& grad_output, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(),
                                    padding, output_padding, stride, dilation, groups);
  return miopen_convolution_backward_input(c, input_size, grad_output, weight,
                                           padding, stride, dilation, groups, benchmark, deterministic);
}

Tensor miopen_convolution_backward_input(
    IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
            weight{ weight_t, "weight", 2 };
  return miopen_convolution_backward_input(
      "miopen_convolution_backward_input",
      input_size, grad_output, weight,
      padding, stride, dilation, groups, benchmark, deterministic);
}

// Depthwise convolutions backward data.
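// Same structure as the standard backward-data path above, but the convolution
// descriptor is created with c_mode = miopenDepthwise.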
void raw_miopen_depthwise_convolution_backward_input_out(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {

  auto dataType = getMiopenDataType(grad_output);
  miopenConvolutionMode_t c_mode = miopenDepthwise;

  ConvolutionArgs args{ grad_input, grad_output, weight };
  args.handle = getMiopenHandle();
  setConvolutionParams(&args.params, args.handle, grad_input, weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(grad_input);
  args.wdesc.set(weight, grad_output.suggest_memory_format(), 0);
  args.odesc.set(grad_output);
  args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, deterministic);

  if (benchmark) {
    miopenConvBwdDataAlgorithm_t bwdDataAlg;
    Workspace workspace = chooseAlgorithm(args, benchmark, &bwdDataAlg);

    Constant one(dataType, 1);
    Constant zero(dataType, 0);

    MIOPEN_CHECK(miopenConvolutionBackwardData(
        args.handle,
        &one, args.odesc.desc(), grad_output.const_data_ptr(),
        args.wdesc.desc(), weight.const_data_ptr(),
        args.cdesc.desc(), bwdDataAlg, &zero,
        args.idesc.desc(), grad_input.mutable_data_ptr(), workspace.data, workspace.size));
  }
  else {
    uint64_t solution_id;
    Workspace workspace = chooseSolution<miopenConvBwdDataAlgorithm_t>(args, &solution_id);

    MIOPEN_CHECK(miopenConvolutionBackwardDataImmediate(
        args.handle,
        args.odesc.desc(), grad_output.const_data_ptr(),
        args.wdesc.desc(), weight.const_data_ptr(),
        args.cdesc.desc(),
        args.idesc.desc(), grad_input.mutable_data_ptr(), workspace.data, workspace.size, solution_id));
  }
}

Tensor miopen_depthwise_convolution_backward_input(
    CheckedFrom c,
    IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {grad_output, weight});
  checkAllSameGPU(c, {grad_output, weight});

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*grad_output, *weight)) {
    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  }

  Tensor grad_input_t = at::detail::empty_cuda(
      input_size, grad_output->options().memory_format(memory_format));

  TensorArg grad_input{ grad_input_t, "result", 0 };
  convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);

  // See #4500
  Tensor weight_contig = weight->contiguous(memory_format);
  // Make sure that NC11 strides follow formula
  weight_contig.resize_(weight_contig.sizes(), memory_format);

  Tensor grad_output_contig = grad_output->contiguous(memory_format);
  grad_output_contig.resize_(grad_output_contig.sizes(), memory_format);

  raw_miopen_depthwise_convolution_backward_input_out(
      *grad_input, grad_output_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic);

  return *grad_input;
}

Tensor miopen_depthwise_convolution_backward_input(
    IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
            weight{ weight_t, "weight", 2 };
  return miopen_depthwise_convolution_backward_input(
      "miopen_depthwise_convolution_backward_input",
      input_size, grad_output, weight,
      padding, stride, dilation, groups, benchmark, deterministic);
}

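// Dispatcher-facing backward entry points: each computes grad_input, grad_weight
// and grad_bias independently, skipping any output whose slot in output_mask is false.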
std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {

  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = miopen_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[1]) {
    grad_weight = miopen_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[2]) {
    grad_bias = miopen_convolution_backward_bias(grad_output);
  }

  return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias};
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_depthwise_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {

  Tensor grad_output = grad_output_t.contiguous();

  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = miopen_depthwise_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[1]) {
    grad_weight = miopen_depthwise_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[2]) {
    grad_bias = miopen_convolution_backward_bias(grad_output);
  }

  return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias};
}

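// Public transposed-convolution entry point: runs the transposed forward pass
// (backward-data of a regular convolution) and then adds the optional bias.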
Tensor miopen_convolution_transpose(
    const Tensor& input_t, const Tensor& weight_t, const std::optional<Tensor>& bias_t_opt,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt);
  const Tensor& bias_t = *bias_t_maybe_owned;

  TensorArg input  { input_t,  "input",  1 },
            weight { weight_t, "weight", 2 },
            bias   { bias_t,   "bias",   3 };
  CheckedFrom c = "miopen_convolution_transpose";
  auto output_t = miopen_convolution_transpose_forward(
      c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic);
  if (bias->defined()) {
    miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias);
  }
  return output_t;
}

// MIOpen fused convolution bias activation forward
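// Builds a vertical MIOpen fusion plan (conv -> bias -> ReLU), compiles it,
// binds the conv/bias/activation arguments, and executes it as a single fused
// operation.  alpha/beta and the activation coefficients are fixed to the
// values that give y = relu(conv(x) + bias).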
void raw_miopen_convolution_relu_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic) {

  auto dataType = getMiopenDataType(input);
  miopenConvolutionMode_t c_mode = miopenConvolution;

  ConvolutionArgs args{ input, output, weight };
  args.handle = getMiopenHandle();
  setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(input);
  args.wdesc.set(weight, input.suggest_memory_format(), 0);
  args.odesc.set(output);
  args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, deterministic);

  TensorDescriptor bdesc;
  bdesc.set(bias.expand({1, bias.size(0)}), output.dim());

  // Create the fusion plan
  miopenFusionPlanDescriptor_t fusePlanDesc;
  miopenFusionOpDescriptor_t convoOp;
  miopenFusionOpDescriptor_t biasOp;
  miopenFusionOpDescriptor_t activOp;
  MIOPEN_CHECK(miopenCreateFusionPlan(&fusePlanDesc, miopenVerticalFusion, args.idesc.desc()));
  MIOPEN_CHECK(miopenCreateOpConvForward(fusePlanDesc, &convoOp, args.cdesc.desc(), args.wdesc.desc()));
  MIOPEN_CHECK(miopenCreateOpBiasForward(fusePlanDesc, &biasOp, bdesc.desc()));
  MIOPEN_CHECK(miopenCreateOpActivationForward(fusePlanDesc, &activOp, miopenActivationRELU));

  // compile fusion plan
  MIOPEN_CHECK(miopenCompileFusionPlan(args.handle, fusePlanDesc));

  // Set the Args
  float alpha = static_cast<float>(1);
  float beta = static_cast<float>(0);
  float activ_alpha = static_cast<float>(0);
  float activ_beta = static_cast<float>(0);
  float activ_gamma = static_cast<float>(0);
  miopenOperatorArgs_t fusionArgs;
  MIOPEN_CHECK(miopenCreateOperatorArgs(&fusionArgs));
  MIOPEN_CHECK(miopenSetOpArgsConvForward(fusionArgs, convoOp, &alpha, &beta, weight.const_data_ptr()));
  MIOPEN_CHECK(miopenSetOpArgsBiasForward(fusionArgs, biasOp, &alpha, &beta, bias.const_data_ptr()));
  MIOPEN_CHECK(miopenSetOpArgsActivForward(fusionArgs, activOp, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));

  miopenExecuteFusionPlan(args.handle, fusePlanDesc, args.idesc.desc(), input.const_data_ptr(), args.odesc.desc(), output.data_ptr(), fusionArgs);

  // Cleanup
  miopenDestroyFusionPlan(fusePlanDesc);
}

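// Returns self unchanged if it is already contiguous in the requested memory
// format; otherwise returns a freshly allocated tensor with that format (the
// caller is responsible for copying the data over).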
static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat memory_format) {
  if (self.is_contiguous(memory_format)) {
    return self;
  }
  return at::empty_like(self, self.options(), memory_format);
}

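// Unfused fallback for conv + add + ReLU: run the convolution through MIOpen,
// then apply the bias, the scaled residual z, and the ReLU as separate eager ops.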
Tensor miopen_convolution_add_relu(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    const std::optional<Scalar>& alpha,
    const std::optional<Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {

  // MIOpen does not support fusing the add, i.e. the alpha2 * z term of the
  // cuDNN fused epilogue: y = act( alpha1 * conv(x) + alpha2 * z + bias )

  auto memory_format = input.suggest_memory_format();

  auto& ctx = at::globalContext();
  bool benchmark = ctx.benchmarkCuDNN();

  TensorArg input_arg  { input,  "input",  1 },
            weight_arg { weight, "weight", 2 };
  auto output = miopen_convolution_forward(
      "miopen_convolution_add_relu",
      input_arg,
      weight_arg,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      false // deterministic
  );

  auto contig_output = self_or_new_memory_format(output, memory_format);

  if (!output.is_same(contig_output)) {
    contig_output.copy_(output);
  }

  auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
  auto _bias = bias.has_value()
      ? bias.value()
      : at::zeros(
            {contig_output.size(1)},
            optTypeMetaToScalarType(contig_output.options().dtype_opt()),
            contig_output.options().layout_opt(),
            contig_output.options().device_opt(),
            contig_output.options().pinned_memory_opt());

  at::Tensor alpha_mul_z_add_bias = at::native::reshape_bias(input.dim(), _bias).add(z, _alpha);
  contig_output.add_(alpha_mul_z_add_bias);
  contig_output.relu_();

  return contig_output;
}

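// conv + ReLU: uses the fused MIOpen plan (raw_miopen_convolution_relu_out) when
// the input is contiguous, fp32 and 4-d; otherwise falls back to an unfused
// convolution followed by a bias add and relu_().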
Tensor miopen_convolution_relu(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {

  auto memory_format = input.suggest_memory_format();

  auto& ctx = at::globalContext();
  bool benchmark = ctx.benchmarkCuDNN();

  // MIOpen fusion currently only supports MemoryFormat::Contiguous, fp32, and 2d convolutions
  if (input.suggest_memory_format() == at::MemoryFormat::Contiguous
      && input.scalar_type() == at::kFloat
      && input.ndimension() == 4) {

    // FuseFrozenConvAddRelu performs some tensor shape checking
    Tensor output_t = at::detail::empty_cuda(
        conv_output_size(
            input.sizes(), weight.sizes(), padding, stride, dilation),
        input.options().memory_format(input.suggest_memory_format()));
    if (output_t.numel() == 0) {
      return output_t;
    }

    auto _bias = bias.has_value()
        ? bias.value()
        : at::zeros(
              {output_t.size(1)},
              optTypeMetaToScalarType(output_t.options().dtype_opt()),
              output_t.options().layout_opt(),
              output_t.options().device_opt(),
              output_t.options().pinned_memory_opt());

    raw_miopen_convolution_relu_out(
        output_t,
        input,
        weight,
        _bias,
        stride,
        padding,
        dilation,
        groups,
        benchmark, // benchmark
        false // deterministic
    );

    return output_t;
  }
  else {
    // fallback

    TensorArg input_arg  { input,  "input",  1 },
              weight_arg { weight, "weight", 2 };
    auto output = miopen_convolution_forward(
        "miopen_convolution_relu",
        input_arg,
        weight_arg,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        false // deterministic
    );

    auto contig_output = self_or_new_memory_format(output, memory_format);

    if (!output.is_same(contig_output)) {
      contig_output.copy_(output);
    }

    auto _bias = bias.has_value()
        ? bias.value()
        : at::zeros(
              {contig_output.size(1)},
              optTypeMetaToScalarType(contig_output.options().dtype_opt()),
              contig_output.options().layout_opt(),
              contig_output.options().device_opt(),
              contig_output.options().pinned_memory_opt());

    at::Tensor reshaped_bias = at::native::reshape_bias(input.dim(), _bias);
    contig_output.add_(reshaped_bias);
    contig_output.relu_();

    return contig_output;
  }
}

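// Register the MIOpen backward implementations with the convolution backward
// dispatch stubs (ROCm builds reuse the CUDA dispatch key).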
REGISTER_CUDA_DISPATCH(miopen_convolution_backward_stub, &miopen_convolution_backward);
REGISTER_CUDA_DISPATCH(miopen_convolution_transpose_backward_stub, &miopen_convolution_transpose_backward);
REGISTER_CUDA_DISPATCH(miopen_depthwise_convolution_backward_stub, &miopen_depthwise_convolution_backward);

}} // namespace

#endif
