#include <vector>

#include <ATen/core/ATen_fwd.h>
#include <ATen/core/interned_strings.h>
#include <ATen/ops/full.h>
#include <ATen/ops/neg.h>
#include <c10/core/Scalar.h>
#include <c10/util/Exception.h>
#include <optional>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <torch/library.h>
#include <ATen/native/ConvUtils.h>

using namespace dnnl;
using namespace at::native;
using namespace at::native::onednn;

namespace at::native {
namespace xpu {
namespace impl {

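// Local mirror of the convolution hyper-parameters coming from the frontend,
// kept in a plain struct so they can be normalized (expanded to the conv
// dimensionality, 1d viewed as 2d) and validated before dispatching to oneDNN.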
struct ConvParams {
  std::vector<int64_t> stride;
  std::vector<int64_t> padding;
  std::vector<int64_t> dilation;
  bool transposed;
  std::vector<int64_t> output_padding;
  int groups;
  bool benchmark;
  bool deterministic;

  bool is_strided() const;
  bool is_dilated() const;
  bool is_padded() const;
  bool is_output_padding_neg() const;
  bool is_output_padding_big() const;
  bool is_padding_neg() const;
  bool is_stride_nonpos() const;
  void view1d_as_2d();
  bool use_cpu_depthwise3x3_winograd(
      const at::Tensor& input,
      const at::Tensor& weight) const;
  bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
};

std::ostream& operator<<(std::ostream& out, const ConvParams& params) {
  out << "ConvParams {"
      << " stride = " << IntArrayRef{params.stride}
      << " padding = " << IntArrayRef{params.padding}
      << " dilation = " << IntArrayRef{params.dilation}
      << " transposed = " << params.transposed
      << " output_padding = " << IntArrayRef{params.output_padding}
      << " groups = " << params.groups << " benchmark = " << params.benchmark
      << " deterministic = " << params.deterministic << "}";
  return out;
}

bool ConvParams::is_strided() const {
  bool is_strided = false;
  for (int s : stride) {
    is_strided |= (s != 1);
  }
  return is_strided;
}

bool ConvParams::is_dilated() const {
  bool is_dilated = false;
  for (int d : dilation) {
    is_dilated |= (d != 1);
  }
  return is_dilated;
}

bool ConvParams::is_padded() const {
  bool is_padded = false;
  for (int p : padding) {
    is_padded |= (p != 0);
  }
  return is_padded;
}

bool ConvParams::is_output_padding_neg() const {
  bool is_neg = false;
  for (int p : output_padding) {
    is_neg |= (p < 0);
  }
  return is_neg;
}

bool ConvParams::is_output_padding_big() const {
  bool is_big = false;
  for (size_t i = 0; i < output_padding.size(); i++) {
    is_big |=
        (output_padding[i] >= stride[i] || output_padding[i] >= dilation[i]);
  }
  return is_big;
}

bool ConvParams::is_padding_neg() const {
  bool is_neg = false;
  for (int p : padding) {
    is_neg |= (p < 0);
  }
  return is_neg;
}

bool ConvParams::is_stride_nonpos() const {
  bool is_nonpos = false;
  for (int s : stride) {
    is_nonpos |= (s <= 0);
  }
  return is_nonpos;
}

void ConvParams::view1d_as_2d() {
  if (stride.size() == 1) {
    stride.insert(stride.begin(), 1);
    padding.insert(padding.begin(), 0);
    dilation.insert(dilation.begin(), 1);
    output_padding.insert(output_padding.begin(), 0);
  }
}

bool ConvParams::use_cpu_depthwise3x3_winograd(
    const at::Tensor& input,
    const at::Tensor& weight) const {
  return false;
}

bool ConvParams::is_depthwise(const at::Tensor& input, const at::Tensor& weight)
    const {
  return !transposed && input.ndimension() == 4 && input.size(1) == groups &&
      groups > 1 && // no point if there is only a single group
      weight.size(0) % input.size(1) ==
          0; // output channels must be a multiple of input channels
}

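// Validates the input/weight/bias shapes against the convolution parameters
// and raises a descriptive TORCH_CHECK error on mismatch. When input_is_mkldnn
// is set and the weight carries an extra leading (group) dimension, the first
// two weight dims are folded together before checking.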
static void check_shape_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    const ConvParams& params,
    bool input_is_mkldnn) {
  int64_t k = input.ndimension();
  int64_t weight_dim = weight.ndimension();
  std::vector<int64_t> weight_sizes(weight_dim);
  if ((weight_dim == k + 1) && input_is_mkldnn) {
    weight_sizes[0] = weight.size(0) * weight.size(1);
    std::copy_n(weight.sizes().cbegin() + 2, k - 1, weight_sizes.begin() + 1);
    weight_dim = k;
  } else {
    std::copy_n(weight.sizes().cbegin(), weight_dim, weight_sizes.begin());
  }
  int64_t groups = params.groups;
  auto padding = params.padding;
  auto output_padding = params.output_padding;
  auto stride = params.stride;
  auto dilation = params.dilation;
  bool transposed = params.transposed;

  TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported");
  TORCH_CHECK(
      !params.is_output_padding_neg(),
      "negative output_padding is not supported");
  TORCH_CHECK(
      !params.is_stride_nonpos(), "non-positive stride is not supported");

  TORCH_CHECK(
      weight_dim == k,
      "Expected ",
      weight_dim,
      "-dimensional input for ",
      weight_dim,
      "-dimensional weight ",
      weight_sizes,
      ", but got ",
      k,
      "-dimensional input of size ",
      input.sizes(),
      " instead");
  TORCH_CHECK(
      weight_sizes[0] >= groups,
      "Given groups=",
      groups,
      ", expected weight to be at least ",
      groups,
      " at dimension 0, but got weight of size ",
      weight_sizes,
      " instead");
  TORCH_CHECK(
      weight_sizes[0] % groups == 0,
      "Given groups=",
      groups,
      ", expected weight to be divisible by ",
      groups,
      " at dimension 0, but got weight of size ",
      weight_sizes,
      " instead");

  if (!transposed) {
    std::vector<int64_t> input_shape;
    std::vector<int64_t> kernel_shape;
    bool kernel_size_correct = true;

    TORCH_CHECK(
        input.size(1) == (weight_sizes[1] * groups),
        "Given groups=",
        groups,
        ", weight of size ",
        weight_sizes,
        ", expected input",
        input.sizes(),
        " to have ",
        (weight_sizes[1] * groups),
        " channels, but got ",
        input.size(1),
        " channels instead");
    TORCH_CHECK(
        !bias.defined() ||
            (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]),
        "Given weight of size ",
        weight_sizes,
        ", expected bias to be 1-dimensional with ",
        weight_sizes[0],
        " elements",
        ", but got bias of size ",
        bias.sizes(),
        " instead");

    for (int i = 2; i < k; ++i) {
      input_shape.push_back(input.size(i) + 2 * padding[i - 2]);
      kernel_shape.push_back(dilation[i - 2] * (weight_sizes[i] - 1) + 1);
      if (input_shape.back() < kernel_shape.back()) {
        kernel_size_correct = false;
      }
    }

    TORCH_CHECK(
        input_shape.size() == kernel_shape.size(),
        "Inconsistent shape between Input and Kernel");

    if (!kernel_size_correct) {
      std::ostringstream input_ss;
      std::ostringstream kernel_ss;
      std::ostringstream output_ss;
      std::string separator = "";

      for (int i = 0, len = input_shape.size(); i < len; ++i) {
        input_ss << separator << input_shape[i];
        kernel_ss << separator << kernel_shape[i];
        separator = " x ";
      }

      TORCH_CHECK(
          0,
          "Calculated padded input size per channel: (",
          input_ss.str(),
          "). "
          "Kernel size: (",
          kernel_ss.str(),
          "). Kernel size can't be greater than actual input size");
    }
  } else {
    TORCH_CHECK(
        input.size(1) == weight_sizes[0],
        "Given transposed=",
        transposed,
        ", weight of size ",
        weight_sizes,
        ", expected input",
        input.sizes(),
        " to have ",
        weight_sizes[0],
        " channels, but got ",
        input.size(1),
        " channels instead");
    TORCH_CHECK(
        !bias.defined() ||
            (bias.ndimension() == 1 &&
             bias.size(0) == weight_sizes[1] * groups),
        "Given transposed=",
        transposed,
        ", weight of size ",
        weight_sizes,
        ", expected bias to be 1-dimensional with ",
        weight_sizes[1] * groups,
        " elements",
        ", but got bias of size ",
        bias.sizes(),
        " instead");
  }
}

static at::Tensor view4d(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.ndimension() == 3,
      "expected 3D tensor, got tensor with ",
      tensor.ndimension(),
      " dimensions instead");
  return tensor.unsqueeze(2);
}

static at::Tensor view3d(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.ndimension() == 4,
      "expected 4D tensor, got tensor with ",
      tensor.ndimension(),
      " dimensions instead");
  return tensor.squeeze(2);
}

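// Builds the oneDNN post-op attribute that fuses
//   output = conv(input, weight) + scale * accumu
// into a single kernel. If the accumulated tensor cannot be used as a
// sum/binary post-op (binary_valid fails), is_fused is set to false and the
// caller is expected to add accumu separately after the convolution.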
Attr get_onednn_conv_sum_attr(
    const Tensor& input_r,
    const Tensor& weight_r,
    IntArrayRef stride_,
    IntArrayRef padding_,
    IntArrayRef dilation_,
    Tensor& accumu,
    double scale,
    Tensor& output,
    bool& is_fused,
    Attr attr = Attr(),
    bool force_inplace = false) {
  is_fused = true;
  if (scale == 0.f)
    return attr;

  auto ndim = input_r.ndimension();
  auto output_size = conv_dst_size(
      ndim,
      input_r.sizes(),
      weight_r.sizes(),
      padding_,
      padding_,
      stride_,
      dilation_);
  MemoryFormat mem_fmt = at::MemoryFormat::Contiguous;
  auto input_fmt = input_r.suggest_memory_format();
  auto input_is_cl =
      (input_fmt == at::MemoryFormat::ChannelsLast ||
       input_fmt == at::MemoryFormat::ChannelsLast3d);
  auto weight_fmt = weight_r.suggest_memory_format();
  auto weight_is_cl =
      (weight_fmt == at::MemoryFormat::ChannelsLast ||
       weight_fmt == at::MemoryFormat::ChannelsLast3d);

  bool propagate_channels_last = input_is_cl || weight_is_cl;
  if (propagate_channels_last)
    mem_fmt = get_cl_tag_by_ndim(ndim);

  Tensor out = at::empty(output_size, input_r.options().memory_format(mem_fmt));
  if (!onednn::binary_valid(out, accumu)) {
    is_fused = false;
    return attr;
  }

  // For post-sum and post-binary-add, oneDNN requires the sum/binary scale to
  // be 1.f, so we rewrite
  //   conv(src, wei) + scale * accumu
  // as
  //   scale * (1/scale * conv(src, wei) + accumu),
  // where accumu is applied as a sum (or binary-add) post-op.
  if (scale != 1.f)
    attr.append_post_eltwise(
        /* scale */ 1.f,
        /* alpha */ 1.f / scale,
        /* beta */ 0.f,
        attr.kind_with_linear);

  if (force_inplace) {
    // If sizes are the same, post sum is used.
    output = accumu;
    attr.append_post_sum(/* sum_scale */ 1.f);
  } else {
    // If sizes are different, post binary is used.
    attr.append_post_binary(attr.kind_with_binary_add, accumu);
  }

  if (scale != 1.f)
    attr.append_post_eltwise(
        /* scale */ 1.f,
        /* alpha */ scale,
        /* beta */ 0.f,
        attr.kind_with_linear);

  return attr;
}

} // namespace impl

using namespace impl;

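// Shared implementation for XPU convolution and transposed convolution.
// Normalizes 1-D convolutions to 2-D (PyTorch has no ChannelsLast1d memory
// format), picks the computation memory format, optionally folds extra
// asymmetric padding from pad_nd into the oneDNN padding, and writes the
// result into output_r (allocating it when undefined).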
Tensor _convolution_out(
    Tensor& output_r,
    const Tensor& input_r,
    const Tensor& weight_r,
    const Tensor& bias_r,
    IntArrayRef stride_,
    IntArrayRef padding_,
    IntArrayRef dilation_,
    bool transposed_,
    IntArrayRef output_padding_,
    int64_t groups_,
    Attr attr,
    IntArrayRef pad_nd = IntArrayRef({})) {
  auto ndim = input_r.ndimension();
  TORCH_CHECK(
      3 == ndim || 4 == ndim || 5 == ndim,
      "convolution only supports 3D, 4D, 5D tensor");
  // get computation format for Conv/TransposedConv
  bool is_channels_last_suggested =
      use_channels_last_for_conv(input_r, weight_r, transposed_);

  Tensor input = input_r, weight = weight_r;
  // PyTorch does not support ChannelsLast1D case,
  // thus we need the transformation here
  if (ndim == 3) {
    input = view4d(input_r);
    weight = view4d(weight_r);
  }
  // ensure the input/weight/bias/output are contiguous in the desired format
  at::MemoryFormat mfmt = is_channels_last_suggested
      ? get_cl_tag_by_ndim(input.ndimension())
      : at::MemoryFormat::Contiguous;
  auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
  input = input.contiguous(mfmt);
  weight = weight.contiguous(mfmt);

  auto k = weight.ndimension();
  if (k == input.ndimension() + 1) {
    k = input.ndimension();
  }
  int64_t dim = k - 2;
  TORCH_CHECK(dim > 0, "weight should have at least three dimensions");

  ConvParams params;
  if (ndim == 3) {
    // PyTorch does not support ChannelsLast1D case,
    // thus we need the transformation here
    params.stride = stride_.vec();
    params.padding = padding_.vec();
    params.dilation = dilation_.vec();
    params.transposed = transposed_;
    params.output_padding = output_padding_.vec();
    params.groups = groups_;
    params.view1d_as_2d();
  } else {
    params.stride = expand_param_if_needed(stride_, "stride", dim);
    // PyTorch default Conv padding is either a single integer value or a list
    // of values matching the conv dimensionality:
    //   conv2d: the number of padding values should be 1 or 2
    //   conv3d: the number of padding values should be 1 or 3
    // Each padding value is applied to both sides of the Conv input (D, H, W).
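    // For illustration, assuming the usual expand_param_if_needed semantics:
    // a conv2d (dim == 2) call with padding {1} is expanded to {1, 1}, a
    // conv3d (dim == 3) call with {1} to {1, 1, 1}, and a full-length list is
    // passed through unchanged.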
    params.padding = expand_param_if_needed(padding_, "padding", dim);
    params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
    params.transposed = transposed_;
    params.output_padding =
        expand_param_if_needed(output_padding_, "output_padding", dim);
    params.groups = groups_;
  }
  check_shape_forward(input, weight, bias, params, true);

  Tensor output;
  if (transposed_) {
    // create output and propagate memory format
    if (!output_r.defined()) {
      auto dst_tz = deconv_dst_size(
          input.sizes(),
          weight.sizes(),
          params.padding,
          params.stride,
          params.dilation,
          params.output_padding,
          params.groups);
      output = at::empty(dst_tz, input.options(), mfmt);
    }

    onednn::deconvolution(
        output,
        input,
        weight,
        bias,
        params.stride,
        params.padding,
        params.output_padding,
        params.dilation,
        params.groups,
        attr);
  } else {
    // oneDNN supports padding the two sides of src with different values;
    // the padding order is front_top_left and back_bottom_right
    auto padding_front_top_left = params.padding;
    auto padding_back_bottom_right = params.padding;

    // PyTorch constant_pad_nd:
    // can pad different values on the two sides of the Conv input (W, H, D):
    // (padding_left, padding_right,
    //  padding_top, padding_bottom,
    //  padding_front, padding_back)
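    // For example, with dim == 3 and pad_nd = {left, right, top, bottom,
    // front, back}, the loop below adds pad_nd[4] / pad_nd[2] / pad_nd[0] to
    // the front / top / left padding and pad_nd[5] / pad_nd[3] / pad_nd[1] to
    // the back / bottom / right padding, matching the (D, H, W) order of
    // params.padding.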
    if (pad_nd.vec().size() > 0) {
      for (int i = 0; i < dim; ++i) {
        padding_front_top_left[i] += pad_nd[2 * dim - 2 * i - 2]; // 4, 2, 0
        padding_back_bottom_right[i] += pad_nd[2 * dim - 2 * i - 1]; // 5, 3, 1
      }
    }

    // create output and propagate memory format
    if (!output_r.defined()) {
      auto dst_tz = conv_dst_size(
          input.ndimension(),
          input.sizes(),
          weight.sizes(),
          padding_front_top_left,
          padding_back_bottom_right,
          params.stride,
          params.dilation);
      output = at::empty(dst_tz, input.options(), mfmt);
    }
    onednn::convolution(
        output,
        input,
        weight,
        bias,
        padding_front_top_left,
        padding_back_bottom_right,
        params.stride,
        params.dilation,
        params.groups,
        attr);
  }

  if (ndim == 3) {
    output = view3d(output);
  }
  if (output_r.defined() && !output_r.is_same(output)) {
    output_r.copy_(output);
  } else {
    output_r = output;
  }
  return output_r;
}

Tensor _convolution(
    const Tensor& input_r,
    const Tensor& weight_r,
    const Tensor& bias_r,
    IntArrayRef stride_,
    IntArrayRef padding_,
    IntArrayRef dilation_,
    bool transposed_,
    IntArrayRef output_padding_,
    int64_t groups_,
    Attr attr) {
  Tensor output_r;
  return _convolution_out(
      output_r,
      input_r,
      weight_r,
      bias_r,
      stride_,
      padding_,
      dilation_,
      transposed_,
      output_padding_,
      groups_,
      attr);
}

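// Entry point for aten::convolution_overrideable on XPU. Resolves the optional
// bias, picks a channels-last memory format when beneficial, and forwards to
// _convolution with an empty post-op Attr.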
Tensor convolution_overrideable(
    const Tensor& input_r,
    const Tensor& weight_r,
    const std::optional<at::Tensor>& bias_r_opt,
    IntArrayRef stride_,
    IntArrayRef padding_,
    IntArrayRef dilation_,
    bool transposed_,
    IntArrayRef output_padding_,
    int64_t groups_) {
  c10::MaybeOwned<Tensor> bias_r_maybe_owned =
      at::borrow_from_optional_tensor(bias_r_opt);
  const Tensor& bias_r = *bias_r_maybe_owned;

  auto k = weight_r.ndimension();
  at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
  if (xpu_conv_use_channels_last(input_r, weight_r)) {
    backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
                                     : at::MemoryFormat::ChannelsLast;
  }
  Tensor input_c = input_r.contiguous(backend_memory_format);
  Tensor weight_c = weight_r.contiguous(backend_memory_format);

  return _convolution(
      input_c,
      weight_c,
      bias_r,
      stride_,
      padding_,
      dilation_,
      transposed_,
      output_padding_,
      groups_,
      Attr());
}

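// Entry point for aten::convolution_backward_overrideable on XPU. output_mask
// selects which of {grad_input, grad_weight, grad_bias} are computed; 1-D
// convolutions are first viewed as 2-D, mirroring the forward path.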
std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    std::array<bool, 3> output_mask) {
  auto ndim = input.ndimension();
  TORCH_CHECK(
      3 == ndim || 4 == ndim || 5 == ndim,
      "convolution bwd only supports 3D, 4D, 5D tensor");
  TORCH_CHECK(
      grad_output.scalar_type() == ScalarType::Float ||
          grad_output.scalar_type() == ScalarType::BFloat16 ||
          grad_output.scalar_type() == ScalarType::Double ||
          grad_output.scalar_type() == ScalarType::Half,
      "so far only support float, bfloat16, half and double convolution backward in XPU backend, your data type is ",
      grad_output.scalar_type());

  bool is_channels_last_suggested =
      use_channels_last_for_conv(input, weight, transposed);

  Tensor grad_output_, input_, weight_;
  IntArrayRef stride_, padding_, dilation_, output_padding_;
  bool transposed_;
  int64_t groups_;
  ConvParams params;
  if (3 == ndim) {
    grad_output_ = view4d(grad_output);
    input_ = view4d(input);
    weight_ = view4d(weight);
    params.stride = stride.vec();
    params.padding = padding.vec();
    params.dilation = dilation.vec();
    params.transposed = transposed;
    params.output_padding = output_padding.vec();
    params.groups = groups;
    params.view1d_as_2d();
    stride_ = params.stride;
    padding_ = params.padding;
    dilation_ = params.dilation;
    transposed_ = params.transposed;
    output_padding_ = params.output_padding;
    groups_ = params.groups;
  } else {
    grad_output_ = grad_output;
    input_ = input;
    weight_ = weight;
    stride_ = stride;
    padding_ = padding;
    dilation_ = dilation;
    transposed_ = transposed;
    output_padding_ = output_padding;
    groups_ = groups;
  }

  // ensure the tensors are contiguous
  auto mfmt = is_channels_last_suggested
      ? get_cl_tag_by_ndim(input_.ndimension())
      : at::MemoryFormat::Contiguous;
  grad_output_ = grad_output_.contiguous(mfmt);
  weight_ = weight_.contiguous(mfmt);
  input_ = input_.contiguous(mfmt);

  auto opt = grad_output_.options();
  Tensor grad_input = at::empty(input_.sizes(), opt, mfmt);
  Tensor grad_weight = at::empty(weight_.sizes(), opt, mfmt);
  Tensor grad_bias;
  if (output_mask[2])
    grad_bias = at::empty({grad_output_.size(1)}, opt);

  if (output_mask[0]) {
    if (input.numel() > 0) {
      if (transposed_) {
        onednn::deconvolution_backward_data(
            grad_input,
            grad_output_,
            weight_,
            stride_,
            padding_,
            dilation_,
            groups_,
            output_mask[2]);
      } else {
        onednn::convolution_backward_data(
            grad_input,
            grad_output_,
            weight_,
            padding_,
            padding_,
            stride_,
            dilation_,
            groups_,
            output_mask[2]);
      }
    }
  }
  if (output_mask[1] || output_mask[2]) {
    if (input.numel() > 0) {
      if (transposed_) {
        onednn::deconvolution_backward_weights(
            grad_weight,
            grad_bias,
            grad_output_,
            input_,
            stride_,
            padding_,
            dilation_,
            groups_);
      } else {
        onednn::convolution_backward_weights(
            grad_weight,
            grad_bias,
            grad_output_,
            input_,
            weight_.sizes(),
            padding_,
            padding_,
            stride_,
            dilation_,
            groups_);
      }
    }
  }

  if (3 == ndim) {
    if (output_mask[0])
      grad_input = view3d(grad_input);
    grad_weight = view3d(grad_weight);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

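// Register the XPU implementations of the overrideable convolution ops with
// the dispatcher.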
TORCH_LIBRARY_IMPL(aten, XPU, m) {
  m.impl("convolution_overrideable", TORCH_FN(convolution_overrideable));
  m.impl(
      "convolution_backward_overrideable",
      TORCH_FN(convolution_backward_overrideable));
}

} // namespace xpu
} // namespace at::native