// aten/src/ATen/native/mkldnn/xpu/Conv.cpp
#include <vector>

#include <ATen/core/ATen_fwd.h>
#include <ATen/core/interned_strings.h>
#include <ATen/ops/full.h>
#include <ATen/ops/neg.h>
#include <c10/core/Scalar.h>
#include <c10/util/Exception.h>
#include <optional>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <torch/library.h>
#include <ATen/native/ConvUtils.h>

using namespace dnnl;
using namespace at::native;
using namespace at::native::onednn;

namespace at::native {
namespace xpu {
namespace impl {

struct ConvParams {
  std::vector<int64_t> stride;
  std::vector<int64_t> padding;
  std::vector<int64_t> dilation;
  bool transposed;
  std::vector<int64_t> output_padding;
  int groups;
  bool benchmark;
  bool deterministic;

  bool is_strided() const;
  bool is_dilated() const;
  bool is_padded() const;
  bool is_output_padding_neg() const;
  bool is_output_padding_big() const;
  bool is_padding_neg() const;
  bool is_stride_nonpos() const;
  void view1d_as_2d();
  bool use_cpu_depthwise3x3_winograd(
      const at::Tensor& input,
      const at::Tensor& weight) const;
  bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
};
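
// Note: this mirrors the ConvParams helper used by the common convolution
// code; it carries the per-spatial-dim stride/padding/dilation/output_padding
// vectors plus the transposed and groups settings, so the checks below can be
// written against one canonical layout.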

std::ostream& operator<<(std::ostream& out, const ConvParams& params) {
  out << "ConvParams {"
      << "  stride = " << IntArrayRef{params.stride}
      << "  padding = " << IntArrayRef{params.padding}
      << "  dilation = " << IntArrayRef{params.dilation}
      << "  transposed = " << params.transposed
      << "  output_padding = " << IntArrayRef{params.output_padding}
      << "  groups = " << params.groups << "  benchmark = " << params.benchmark
      << "  deterministic = " << params.deterministic << "}";
  return out;
}

bool ConvParams::is_strided() const {
  bool is_strided = false;
  for (int s : stride) {
    is_strided |= (s != 1);
  }
  return is_strided;
}

bool ConvParams::is_dilated() const {
  bool is_dilated = false;
  for (int d : dilation) {
    is_dilated |= (d != 1);
  }
  return is_dilated;
}

bool ConvParams::is_padded() const {
  bool is_padded = false;
  for (int p : padding) {
    is_padded |= (p != 0);
  }
  return is_padded;
}

bool ConvParams::is_output_padding_neg() const {
  bool is_neg = false;
  for (int p : output_padding) {
    is_neg |= (p < 0);
  }
  return is_neg;
}

bool ConvParams::is_output_padding_big() const {
  bool is_big = false;
  for (size_t i = 0; i < output_padding.size(); i++) {
    is_big |=
        (output_padding[i] >= stride[i] || output_padding[i] >= dilation[i]);
  }
  return is_big;
}
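
// "Big" here means output_padding reaches the stride or the dilation in some
// dimension; output_padding must stay below both for a transposed convolution
// to be well-formed.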

bool ConvParams::is_padding_neg() const {
  bool is_neg = false;
  for (int p : padding) {
    is_neg |= (p < 0);
  }
  return is_neg;
}

bool ConvParams::is_stride_nonpos() const {
  bool is_nonpos = false;
  for (int s : stride) {
    is_nonpos |= (s <= 0);
  }
  return is_nonpos;
}

void ConvParams::view1d_as_2d() {
  if (stride.size() == 1) {
    stride.insert(stride.begin(), 1);
    padding.insert(padding.begin(), 0);
    dilation.insert(dilation.begin(), 1);
    output_padding.insert(output_padding.begin(), 0);
  }
}
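
// For example, a 1d conv with stride {2}, padding {1}, dilation {1},
// output_padding {0} becomes a 2d conv over a dummy height dimension:
// stride {1, 2}, padding {0, 1}, dilation {1, 1}, output_padding {0, 0}.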

bool ConvParams::use_cpu_depthwise3x3_winograd(
    const at::Tensor& input,
    const at::Tensor& weight) const {
  // The CPU Winograd fast path never applies on the XPU backend.
  return false;
}

bool ConvParams::is_depthwise(const at::Tensor& input, const at::Tensor& weight)
    const {
  return !transposed && input.ndimension() == 4 && input.size(1) == groups &&
      groups > 1 && // no point if there is only a single group
      weight.size(0) % input.size(1) ==
      0; // output channels must be a multiple of input channels
}
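
// For example, input (N, 8, H, W) with groups == 8 and weight (16, 1, kH, kW)
// is depthwise: each input channel is convolved with its own set of filters.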

static void check_shape_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    const ConvParams& params,
    bool input_is_mkldnn) {
  int64_t k = input.ndimension();
  int64_t weight_dim = weight.ndimension();
  std::vector<int64_t> weight_sizes(weight_dim);
  if ((weight_dim == k + 1) && input_is_mkldnn) {
    weight_sizes[0] = weight.size(0) * weight.size(1);
    std::copy_n(weight.sizes().cbegin() + 2, k - 1, weight_sizes.begin() + 1);
    weight_dim = k;
  } else {
    std::copy_n(weight.sizes().cbegin(), weight_dim, weight_sizes.begin());
  }
  int64_t groups = params.groups;
  auto padding = params.padding;
  auto output_padding = params.output_padding;
  auto stride = params.stride;
  auto dilation = params.dilation;
  bool transposed = params.transposed;

  TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported");
  TORCH_CHECK(
      !params.is_output_padding_neg(),
      "negative output_padding is not supported");
  TORCH_CHECK(
      !params.is_stride_nonpos(), "non-positive stride is not supported");

  TORCH_CHECK(
      weight_dim == k,
      "Expected ",
      weight_dim,
      "-dimensional input for ",
      weight_dim,
      "-dimensional weight ",
      weight_sizes,
      ", but got ",
      k,
      "-dimensional input of size ",
      input.sizes(),
      " instead");
  TORCH_CHECK(
      weight_sizes[0] >= groups,
      "Given groups=",
      groups,
      ", expected weight to be at least ",
      groups,
      " at dimension 0, but got weight of size ",
      weight_sizes,
      " instead");
  TORCH_CHECK(
      weight_sizes[0] % groups == 0,
      "Given groups=",
      groups,
      ", expected weight to be divisible by ",
      groups,
      " at dimension 0, but got weight of size ",
      weight_sizes,
      " instead");

  if (!transposed) {
    std::vector<int64_t> input_shape;
    std::vector<int64_t> kernel_shape;
    bool kernel_size_correct = true;

    TORCH_CHECK(
        input.size(1) == (weight_sizes[1] * groups),
        "Given groups=",
        groups,
        ", weight of size ",
        weight_sizes,
        ", expected input",
        input.sizes(),
        " to have ",
        (weight_sizes[1] * groups),
        " channels, but got ",
        input.size(1),
        " channels instead");
    TORCH_CHECK(
        !bias.defined() ||
            (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]),
        "Given weight of size ",
        weight_sizes,
        ", expected bias to be 1-dimensional with ",
        weight_sizes[0],
        " elements",
        ", but got bias of size ",
        bias.sizes(),
        " instead");

    for (int i = 2; i < k; ++i) {
      input_shape.push_back(input.size(i) + 2 * padding[i - 2]);
      kernel_shape.push_back(dilation[i - 2] * (weight_sizes[i] - 1) + 1);
      if (input_shape.back() < kernel_shape.back()) {
        kernel_size_correct = false;
      }
    }

    TORCH_CHECK(
        input_shape.size() == kernel_shape.size(),
        "Inconsistent shape between Input and Kernel");

    if (!kernel_size_correct) {
      std::ostringstream input_ss;
      std::ostringstream kernel_ss;
      std::string separator = "";

      for (int i = 0, len = input_shape.size(); i < len; ++i) {
        input_ss << separator << input_shape[i];
        kernel_ss << separator << kernel_shape[i];
        separator = " x ";
      }

      TORCH_CHECK(
          0,
          "Calculated padded input size per channel: (",
          input_ss.str(),
          "). "
          "Kernel size: (",
          kernel_ss.str(),
          "). Kernel size can't be greater than actual input size");
    }
  } else {
    TORCH_CHECK(
        input.size(1) == weight_sizes[0],
        "Given transposed=",
        transposed,
        ", weight of size ",
        weight_sizes,
        ", expected input",
        input.sizes(),
        " to have ",
        weight_sizes[0],
        " channels, but got ",
        input.size(1),
        " channels instead");
    TORCH_CHECK(
        !bias.defined() ||
            (bias.ndimension() == 1 &&
             bias.size(0) == weight_sizes[1] * groups),
        "Given transposed=",
        transposed,
        ", weight of size ",
        weight_sizes,
        ", expected bias to be 1-dimensional with ",
        weight_sizes[1] * groups,
        " elements",
        ", but got bias of size ",
        bias.sizes(),
        " instead");
  }
}
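
// Channel bookkeeping enforced above: a regular conv expects
// input channels == weight_sizes[1] * groups and a bias of weight_sizes[0]
// elements; a transposed conv expects input channels == weight_sizes[0] and
// a bias of weight_sizes[1] * groups elements.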

static at::Tensor view4d(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.ndimension() == 3,
      "expected 3D tensor, got tensor with ",
      tensor.ndimension(),
      " dimensions instead");
  return tensor.unsqueeze(2);
}

static at::Tensor view3d(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.ndimension() == 4,
      "expected 4D tensor, got tensor with ",
      tensor.ndimension(),
      " dimensions instead");
  return tensor.squeeze(2);
}
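
// view4d/view3d implement the 1d-as-2d trick used throughout this file:
// a (N, C, L) tensor gains a dummy height dimension at dim 2 before being
// handed to oneDNN and drops it again on the way out.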

Attr get_onednn_conv_sum_attr(
    const Tensor& input_r,
    const Tensor& weight_r,
    IntArrayRef stride_,
    IntArrayRef padding_,
    IntArrayRef dilation_,
    Tensor& accumu,
    double scale,
    Tensor& output,
    bool& is_fused,
    Attr attr = Attr(),
    bool force_inplace = false) {
  is_fused = true;
  if (scale == 0.f)
    return attr;

  auto ndim = input_r.ndimension();
  auto output_size = conv_dst_size(
      ndim,
      input_r.sizes(),
      weight_r.sizes(),
      padding_,
      padding_,
      stride_,
      dilation_);
  MemoryFormat mem_fmt = at::MemoryFormat::Contiguous;
  auto input_fmt = input_r.suggest_memory_format();
  auto input_is_cl =
      (input_fmt == at::MemoryFormat::ChannelsLast ||
       input_fmt == at::MemoryFormat::ChannelsLast3d);
  auto weight_fmt = weight_r.suggest_memory_format();
  auto weight_is_cl =
      (weight_fmt == at::MemoryFormat::ChannelsLast ||
       weight_fmt == at::MemoryFormat::ChannelsLast3d);

  bool propagate_channels_last = input_is_cl || weight_is_cl;
  if (propagate_channels_last)
    mem_fmt = get_cl_tag_by_ndim(ndim);

  Tensor out = at::empty(output_size, input_r.options().memory_format(mem_fmt));
  if (!onednn::binary_valid(out, accumu)) {
    is_fused = false;
    return attr;
  }

  // oneDNN requires sum/binary post-ops to use scale = 1.f, so
  //   conv(src, wei) + scale * accumu
  // is rewritten as
  //   scale * (1/scale * conv(src, wei) + accumu)
  // with the two outer factors expressed as linear post-ops.
  if (scale != 1.f)
    attr.append_post_eltwise(
        /* scale */ 1.f,
        /* alpha */ 1.f / scale,
        /* beta */ 0.f,
        attr.kind_with_linear);

  if (force_inplace) {
    // If sizes are the same, post sum is used.
    output = accumu;
    attr.append_post_sum(/* sum_scale */ 1.f);
  } else {
    // If sizes are different, post binary is used.
    attr.append_post_binary(attr.kind_with_binary_add, accumu);
  }

  if (scale != 1.f)
    attr.append_post_eltwise(
        /* scale */ 1.f,
        /* alpha */ scale,
        /* beta */ 0.f,
        attr.kind_with_linear);

  return attr;
}
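
// The post-op chain built above evaluates to
//   scale * (1/scale * conv(src, wei) + accumu) == conv(src, wei) + scale * accumu
// with both linear post-ops omitted when scale == 1.f.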

} // namespace impl

using namespace impl;

Tensor _convolution_out(
    Tensor& output_r,
    const Tensor& input_r,
    const Tensor& weight_r,
    const Tensor& bias_r,
    IntArrayRef stride_,
    IntArrayRef padding_,
    IntArrayRef dilation_,
    bool transposed_,
    IntArrayRef output_padding_,
    int64_t groups_,
    Attr attr,
    IntArrayRef pad_nd = IntArrayRef({})) {
  auto ndim = input_r.ndimension();
  TORCH_CHECK(
      3 == ndim || 4 == ndim || 5 == ndim,
      "convolution only supports 3D, 4D, 5D tensor");
  // get computation format for Conv/TransposedConv
  bool is_channels_last_suggested =
      use_channels_last_for_conv(input_r, weight_r, transposed_);

  Tensor input = input_r, weight = weight_r;
  // PyTorch does not support the ChannelsLast1D case,
  // thus we need the transformation here
  if (ndim == 3) {
    input = view4d(input_r);
    weight = view4d(weight_r);
  }
  // ensure the input/weight/bias/output are contiguous in the desired format
  at::MemoryFormat mfmt = is_channels_last_suggested
      ? get_cl_tag_by_ndim(input.ndimension())
      : at::MemoryFormat::Contiguous;
  auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
  input = input.contiguous(mfmt);
  weight = weight.contiguous(mfmt);

  auto k = weight.ndimension();
  if (k == input.ndimension() + 1) {
    k = input.ndimension();
  }
  int64_t dim = k - 2;
  TORCH_CHECK(dim > 0, "weight should have at least three dimensions");

  ConvParams params;
  if (ndim == 3) {
    // mirror the 1d-as-2d transformation applied to the tensors above
    params.stride = stride_.vec();
    params.padding = padding_.vec();
    params.dilation = dilation_.vec();
    params.transposed = transposed_;
    params.output_padding = output_padding_.vec();
    params.groups = groups_;
    params.view1d_as_2d();
  } else {
    params.stride = expand_param_if_needed(stride_, "stride", dim);
    // PyTorch's default Conv padding is either a single integer
    // or a list of values matching the conv dimensionality:
    // conv2d takes 1 or 2 padding values, conv3d takes 1 or 3,
    // and each value pads both sides of the Conv input (D, H, W).
    params.padding = expand_param_if_needed(padding_, "padding", dim);
    params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
    params.transposed = transposed_;
    params.output_padding =
        expand_param_if_needed(output_padding_, "output_padding", dim);
    params.groups = groups_;
  }
  check_shape_forward(input, weight, bias, params, true);

  Tensor output;
  if (transposed_) {
    // create output and propagate memory format
    if (!output_r.defined()) {
      auto dst_tz = deconv_dst_size(
          input.sizes(),
          weight.sizes(),
          params.padding,
          params.stride,
          params.dilation,
          params.output_padding,
          params.groups);
      output = at::empty(dst_tz, input.options(), mfmt);
    }

    onednn::deconvolution(
        output,
        input,
        weight,
        bias,
        params.stride,
        params.padding,
        params.output_padding,
        params.dilation,
        params.groups,
        attr);
  } else {
    // oneDNN supports padding the two sides of src with different values;
    // the padding order is front_top_left then back_bottom_right
    auto padding_front_top_left = params.padding;
    auto padding_back_bottom_right = params.padding;

    // PyTorch constant_pad_nd:
    // can pad different values onto the two sides of the Conv input (W, H, D)
    // (padding_left, padding_right,
    //  padding_top, padding_bottom,
    //  padding_front, padding_back)
    if (pad_nd.vec().size() > 0) {
      for (int i = 0; i < dim; ++i) {
        padding_front_top_left[i] += pad_nd[2 * dim - 2 * i - 2]; // 4, 2, 0
        padding_back_bottom_right[i] += pad_nd[2 * dim - 2 * i - 1]; // 5, 3, 1
      }
    }

    // create output and propagate memory format
    if (!output_r.defined()) {
      auto dst_tz = conv_dst_size(
          input.ndimension(),
          input.sizes(),
          weight.sizes(),
          padding_front_top_left,
          padding_back_bottom_right,
          params.stride,
          params.dilation);
      output = at::empty(dst_tz, input.options(), mfmt);
    }
    onednn::convolution(
        output,
        input,
        weight,
        bias,
        padding_front_top_left,
        padding_back_bottom_right,
        params.stride,
        params.dilation,
        params.groups,
        attr);
  }

  if (ndim == 3) {
    output = view3d(output);
  }
  if (output_r.defined() && !output_r.is_same(output)) {
    output_r.copy_(output);
  } else {
    output_r = output;
  }
  return output_r;
}
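
// When output_r is predefined, its storage is reused (or copied into if the
// computation produced a different tensor); otherwise the freshly allocated
// output is handed back directly.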

Tensor _convolution(
    const Tensor& input_r,
    const Tensor& weight_r,
    const Tensor& bias_r,
    IntArrayRef stride_,
    IntArrayRef padding_,
    IntArrayRef dilation_,
    bool transposed_,
    IntArrayRef output_padding_,
    int64_t groups_,
    Attr attr) {
  Tensor output_r;
  return _convolution_out(
      output_r,
      input_r,
      weight_r,
      bias_r,
      stride_,
      padding_,
      dilation_,
      transposed_,
      output_padding_,
      groups_,
      attr);
}

Tensor convolution_overrideable(
    const Tensor& input_r,
    const Tensor& weight_r,
    const std::optional<at::Tensor>& bias_r_opt,
    IntArrayRef stride_,
    IntArrayRef padding_,
    IntArrayRef dilation_,
    bool transposed_,
    IntArrayRef output_padding_,
    int64_t groups_) {
  c10::MaybeOwned<Tensor> bias_r_maybe_owned =
      at::borrow_from_optional_tensor(bias_r_opt);
  const Tensor& bias_r = *bias_r_maybe_owned;

  auto k = weight_r.ndimension();
  at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
  if (xpu_conv_use_channels_last(input_r, weight_r)) {
    backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
                                     : at::MemoryFormat::ChannelsLast;
  }
  Tensor input_c = input_r.contiguous(backend_memory_format);
  Tensor weight_c = weight_r.contiguous(backend_memory_format);

  return _convolution(
      input_c,
      weight_c,
      bias_r,
      stride_,
      padding_,
      dilation_,
      transposed_,
      output_padding_,
      groups_,
      Attr());
}
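
// convolution_overrideable is the backend hook that the dispatcher routes
// at::convolution to for XPU tensors; the TORCH_LIBRARY_IMPL block at the
// bottom of this file registers it.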

std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    std::array<bool, 3> output_mask) {
  auto ndim = input.ndimension();
  TORCH_CHECK(
      3 == ndim || 4 == ndim || 5 == ndim,
      "convolution bwd only supports 3D, 4D, 5D tensor");
  TORCH_CHECK(
      grad_output.scalar_type() == ScalarType::Float ||
          grad_output.scalar_type() == ScalarType::BFloat16 ||
          grad_output.scalar_type() == ScalarType::Double ||
          grad_output.scalar_type() == ScalarType::Half,
      "convolution backward on the XPU backend only supports float, bfloat16, half and double so far, but got ",
      grad_output.scalar_type());

  bool is_channels_last_suggested =
      use_channels_last_for_conv(input, weight, transposed);

  Tensor grad_output_, input_, weight_;
  IntArrayRef stride_, padding_, dilation_, output_padding_;
  bool transposed_;
  int64_t groups_;
  ConvParams params;
  if (3 == ndim) {
    grad_output_ = view4d(grad_output);
    input_ = view4d(input);
    weight_ = view4d(weight);
    params.stride = stride.vec();
    params.padding = padding.vec();
    params.dilation = dilation.vec();
    params.transposed = transposed;
    params.output_padding = output_padding.vec();
    params.groups = groups;
    params.view1d_as_2d();
    stride_ = params.stride;
    padding_ = params.padding;
    dilation_ = params.dilation;
    transposed_ = params.transposed;
    output_padding_ = params.output_padding;
    groups_ = params.groups;
  } else {
    grad_output_ = grad_output;
    input_ = input;
    weight_ = weight;
    stride_ = stride;
    padding_ = padding;
    dilation_ = dilation;
    transposed_ = transposed;
    output_padding_ = output_padding;
    groups_ = groups;
  }

  // ensure the tensors are contiguous in the suggested format
  auto mfmt = is_channels_last_suggested
      ? get_cl_tag_by_ndim(input_.ndimension())
      : at::MemoryFormat::Contiguous;
  grad_output_ = grad_output_.contiguous(mfmt);
  weight_ = weight_.contiguous(mfmt);
  input_ = input_.contiguous(mfmt);

  auto opt = grad_output_.options();
  Tensor grad_input = at::empty(input_.sizes(), opt, mfmt);
  Tensor grad_weight = at::empty(weight_.sizes(), opt, mfmt);
  Tensor grad_bias;
  if (output_mask[2])
    grad_bias = at::empty({grad_output_.size(1)}, opt);

  if (output_mask[0]) {
    if (input.numel() > 0) {
      if (transposed_) {
        onednn::deconvolution_backward_data(
            grad_input,
            grad_output_,
            weight_,
            stride_,
            padding_,
            dilation_,
            groups_,
            output_mask[2]);
      } else {
        onednn::convolution_backward_data(
            grad_input,
            grad_output_,
            weight_,
            padding_,
            padding_,
            stride_,
            dilation_,
            groups_,
            output_mask[2]);
      }
    }
  }
  if (output_mask[1] || output_mask[2]) {
    if (input.numel() > 0) {
      if (transposed_) {
        onednn::deconvolution_backward_weights(
            grad_weight,
            grad_bias,
            grad_output_,
            input_,
            stride_,
            padding_,
            dilation_,
            groups_);
      } else {
        onednn::convolution_backward_weights(
            grad_weight,
            grad_bias,
            grad_output_,
            input_,
            weight_.sizes(),
            padding_,
            padding_,
            stride_,
            dilation_,
            groups_);
      }
    }
  }

  if (3 == ndim) {
    if (output_mask[0])
      grad_input = view3d(grad_input);
    grad_weight = view3d(grad_weight);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}
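
// output_mask selects which gradients the caller actually needs, in the order
// {grad_input, grad_weight, grad_bias}; grad_bias stays undefined unless
// output_mask[2] is set.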

TORCH_LIBRARY_IMPL(aten, XPU, m) {
  m.impl("convolution_overrideable", TORCH_FN(convolution_overrideable));
  m.impl(
      "convolution_backward_overrideable",
      TORCH_FN(convolution_backward_overrideable));
}
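
// With these registrations in place, an eager-mode call such as (a sketch,
// assuming an XPU-enabled build)
//   torch.nn.functional.conv2d(x.to("xpu"), w.to("xpu"))
// ends up in convolution_overrideable above via the dispatcher.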

} // namespace xpu
} // namespace at::native