#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <torch/library.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/utils/ParamUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#include <ATen/Functions.h>
#else
#include <ATen/ops/_add_relu_native.h>
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/convolution.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/mkldnn_convolution_native.h>
#endif

#if !AT_MKLDNN_ENABLED()

namespace at { namespace native {

Tensor mkldnn_convolution(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
  TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
}

REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub);
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_stub);
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub);

}}

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/ConvUtils.h>
#include <c10/util/irange.h>

namespace at { namespace native {

// Follow the check rules from native/Convolution.cpp; transpose is not supported here.
static void check_shape_forward(const Tensor& input,
                                const Tensor& weight,
                                const Tensor& bias,
                                const IntArrayRef& padding,
                                const IntArrayRef& stride,
                                const IntArrayRef& dilation,
                                const int64_t groups) {
#define MKLDNN_CONV_ARG_CHECK(IT, OP) std::any_of(IT.begin(), IT.end(), [](auto x) { return x OP 0; })
  auto is_padding_neg = MKLDNN_CONV_ARG_CHECK(padding, <);
  auto is_stride_nonpos = MKLDNN_CONV_ARG_CHECK(stride, <=);
  auto is_dilation_nonpos = MKLDNN_CONV_ARG_CHECK(dilation, <=);
#undef MKLDNN_CONV_ARG_CHECK
  TORCH_CHECK(!is_padding_neg, "negative padding is not supported");
  TORCH_CHECK(!is_stride_nonpos, "non-positive stride is not supported");
  TORCH_CHECK(!is_dilation_nonpos, "non-positive dilation is not supported");
  TORCH_CHECK(groups > 0, "non-positive groups is not supported");

  int64_t k = input.ndimension();
  const IntArrayRef& weight_sizes = weight.sizes();
  int64_t weight_dim = weight_sizes.size();

  TORCH_CHECK(weight_dim == k,
              "Expected ", weight_dim, "-dimensional input for ", weight_dim,
              "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ",
              input.sizes(), " instead");
  TORCH_CHECK(weight_sizes[0] >= groups,
              "Given groups=", groups, ", expected weight to be at least ", groups,
              " at dimension 0, but got weight of size ", weight_sizes, " instead");
  TORCH_CHECK(weight_sizes[0] % groups == 0,
              "Given groups=", groups, ", expected weight to be divisible by ",
              groups, " at dimension 0, but got weight of size [", weight_sizes,
              "] instead");
  TORCH_CHECK(input.size(1) == (weight_sizes[1] * groups),
              "Given groups=", groups, ", weight of size ", weight_sizes,
              ", expected input", input.sizes(), " to have ",
              (weight_sizes[1] * groups), " channels, but got ", input.size(1),
              " channels instead");
  TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]),
              "Given weight of size ", weight_sizes,
              ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",
              ", but got bias of size ", bias.sizes(), " instead");

  std::vector<int64_t> input_shape;
  std::vector<int64_t> kernel_shape;
  bool kernel_size_correct = true;

  for (const auto i : c10::irange(2, k)) {
    input_shape.push_back(input.size(i) + 2 * padding[i-2]);
    // log new kernel size considering dilation
    kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1);
    if (input_shape.back() < kernel_shape.back()) {
      kernel_size_correct = false;
    }
  }

  TORCH_CHECK(input_shape.size() == kernel_shape.size(), "Inconsistent shape between Input and Kernel");

  if (!kernel_size_correct) {
    // If kernel size is incorrect
    std::ostringstream input_ss;
    std::ostringstream kernel_ss;
    std::string separator = "";

    for (int i = 0, len = input_shape.size(); i < len; ++i) {
      input_ss << separator << input_shape[i];
      kernel_ss << separator << kernel_shape[i];
      separator = " x ";
    }

    TORCH_CHECK(false, "Calculated padded input size per channel: (", input_ss.str(), "). "
                "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size");
  }
}
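
// Illustrative example of the kernel-size check above (added commentary, not
// part of the original source): for a 2d input of spatial size 5x5 with
// padding=1, the padded size per dimension is 5 + 2*1 = 7. A 3x3 kernel with
// dilation=3 has an effective extent of 3*(3-1) + 1 = 7, which still fits;
// with dilation=4 the extent becomes 4*(3-1) + 1 = 9 > 7, so
// check_shape_forward raises "Kernel size can't be greater than actual input size".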

#define MKLDNNTensor(itensor, options)                                  \
  new_with_itensor_mkldnn(                                              \
      std::move(itensor),                                               \
      optTypeMetaToScalarType(options.dtype_opt()),                     \
      options.device_opt())

// Note [MKLDNN Convolution Memory Formats]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// MKLDNN has 3 types of memory formats in convolution:
//
// When the memory format passed from PyTorch (aka. the user layout) differs
// from the internal layout used by MKLDNN, a `reorder` is needed;
// otherwise, when the user layout is identical to the internal layout,
// MKLDNN uses a memory `view` upon an existing CPU tensor.
//
// 1. NCHW (CPU tensor, contiguous)
//  input reorder:  NCHW(user) -> Blocked(internal)
//  weight reorder: OIHW(user) -> Blocked(internal)
//  output reorder: Blocked(internal) -> NCHW(user)
//
// 2. NHWC: (CPU tensor, channels last)
//  input view:     NHWC(user) -> NHWC(internal)
//  weight reorder: OHWI(user) -> Blocked(internal)
//  output view:    NHWC(internal) -> NHWC(user)
//
// 3. Blocked (MKLDNN tensor):
//  By explicitly converting a tensor to mkldnn, e.g. `x.to_mkldnn()`,
//  the blocked format propagates between layers; input and output will be in blocked format.
//
//  For the inference case, the weight can be prepacked into blocked format
//  (so as to save the weight reorder overhead) by:
//      model = torch.utils.mkldnn.to_mkldnn(model)
//
//  For the training case, grad_output can be a CPU tensor or an MKLDNN tensor,
//  but weight/bias and grad_weight/grad_bias are always CPU tensors.
//

static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {
  auto memory_format = at::MemoryFormat::Contiguous;
  if (is_channels_last) {
    memory_format = dims == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
  }
  return memory_format;
}
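
// Illustrative usage of the helper above (a sketch, not part of the original
// source): a 4-d channels-last convolution picks at::MemoryFormat::ChannelsLast,
// a 5-d one picks ChannelsLast3d, and anything else falls back to Contiguous.
//
//   auto fmt = mkldnn_convolution_memory_format(/*dims*/4, /*is_channels_last*/true);
//   // fmt == at::MemoryFormat::ChannelsLast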

static void _mkldnn_convolution_out(
    const Tensor& input_t,
    const Tensor& weight_t,
    const Tensor& bias,
    std::vector<int64_t>& output_sizes,
    ideep::tensor& y,
    IntArrayRef stride,
    IntArrayRef dilation,
    IntArrayRef padding,
    int64_t groups,
    bool is_channels_last,
    const ideep::attr_t& op_attr) {
  auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last);
  auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
  auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
  const ideep::tensor x = itensor_from_tensor(input, /*from_const_data_ptr*/true);
  const ideep::tensor w = itensor_from_tensor(weight, /*from_const_data_ptr*/true);
  if (bias.defined()) {
    const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true);
    ideep::convolution_forward::compute_v3(
        x,
        w,
        b,
        {output_sizes.cbegin(), output_sizes.cend()},
        y,
        {stride.begin(), stride.end()},
        {dilation.begin(), dilation.end()},
        {padding.begin(), padding.end()},
        {padding.begin(), padding.end()},
        groups,
        is_channels_last,
        op_attr);
  } else {
    ideep::convolution_forward::compute_v3(
        x,
        w,
        {output_sizes.cbegin(), output_sizes.cend()},
        y,
        {stride.begin(), stride.end()},
        {dilation.begin(), dilation.end()},
        {padding.begin(), padding.end()},
        {padding.begin(), padding.end()},
        groups,
        is_channels_last,
        op_attr);
  }
}
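
// A note on `op_attr` (added commentary, not from the original source): the
// attribute carries oneDNN post-ops that get fused into the convolution
// primitive. The callers below build it either from fusion_unary_attr_map()
// (e.g. attr == "relu") or, for the in-place binary variant, from
// ideep::attr_t::fuse_sum() / ideep::attr_t::residual().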

static Tensor _mkldnn_convolution(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool use_channels_last,
    c10::string_view attr = "none",
    torch::List<std::optional<at::Scalar>> scalars =
        torch::List<std::optional<at::Scalar>>(),
    std::optional<c10::string_view> algorithm = std::nullopt) {
  ideep::attr_t op_attr = ideep::attr_t();
  if (attr != "none") {
    auto it = fusion_unary_attr_map().find(attr);
    TORCH_CHECK(
        it != fusion_unary_attr_map().end(), "Fusion behavior undefined.");
    op_attr = it->second(scalars, algorithm);
  }
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  mkldnn_check_low_precision(input_t.scalar_type(), "mkldnn_convolution");

  int64_t dim = input_t.ndimension() - 2;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", dim);

  check_shape_forward(input_t, weight_t, bias, padding_expanded, stride_expanded, dilation_expanded, groups);

  auto memory_format =
      mkldnn_convolution_memory_format(input_t.ndimension(), use_channels_last);

  auto output_sizes = conv_output_size(input_t.sizes(), weight_t.sizes(), padding_expanded, stride_expanded, dilation_expanded);
  auto output = at::empty({0}, input_t.options());
  ideep::tensor y;
  if (use_channels_last) {
    output.resize_(output_sizes, memory_format);
    y = itensor_from_tensor(output);
  }
  _mkldnn_convolution_out(
      input_t,
      weight_t,
      bias,
      output_sizes,
      y,
      stride_expanded,
      dilation_expanded,
      padding_expanded,
      groups,
      use_channels_last,
      op_attr);

  if (input_t.is_mkldnn()) {
    return MKLDNNTensor(y, input_t.options());
  } else if (!use_channels_last) {
    return mkldnn_to_dense(MKLDNNTensor(y, input_t.options()));
  } else {
    return output;
  }
}

Tensor mkldnn_convolution(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups) {
  bool use_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t);
  return _mkldnn_convolution(
      input_t,
      weight_t,
      bias_opt,
      padding,
      stride,
      dilation,
      groups,
      use_channels_last);
}
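
// A minimal usage sketch (illustrative only, not part of the original source;
// assumes an MKLDNN-enabled build and ATen headers on the include path):
//
//   at::Tensor x = at::randn({1, 3, 8, 8}).to_mkldnn();
//   at::Tensor w = at::randn({4, 3, 3, 3}).to_mkldnn();
//   // argument order: padding, stride, dilation, groups
//   at::Tensor y = at::mkldnn_convolution(x, w, /*bias*/{}, {1, 1}, {1, 1}, {1, 1}, 1);
//   // y is an MKLDNN (blocked) tensor; call y.to_dense() to get a CPU tensor.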

namespace{
Tensor mkldnn_convolution_pointwise(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm) {
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  bool use_channels_last =
      weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t);
  return _mkldnn_convolution(
      input_t,
      weight_t,
      bias_opt,
      padding,
      stride,
      dilation,
      groups,
      use_channels_last,
      attr,
      scalars,
      algorithm);
}
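
// Added commentary (an illustrative sketch, not from the original source):
// with attr == "relu" and 2d inputs, this computes, up to memory layout, the
// same result as the unfused sequence
//   auto y = at::relu(at::convolution(input_t, weight_t, bias_opt, stride, padding,
//                                     dilation, /*transposed*/false,
//                                     /*output_padding*/{0, 0}, groups));
// but as a single oneDNN primitive with the eltwise post-op fused in.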

// Fuse convolution + binary_op + unary_op for better performance. The fused
// kernel computes: output = unary_op(binary_op(conv(input_t, ...), other_t, alpha)).
// binary_attr selects the binary op, e.g. "add" or another binary operation.
// unary_attr selects the unary op, e.g. "relu"; if it is none, no unary post
// op is applied. unary_scalars and unary_algorithm are the parameters of the
// unary op, e.g. "hardtanh" takes scalar parameters and "gelu" takes an
// algorithm parameter.
Tensor mkldnn_convolution_pointwise_binary(
    const Tensor& input_t,
    const Tensor& other_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view binary_attr,
    std::optional<at::Scalar> alpha,
    std::optional<c10::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
    std::optional<c10::string_view> unary_algorithm) {
  TORCH_CHECK(
      input_t.ndimension() == 4 || input_t.ndimension() == 5,
      "mkldnn_convolution_pointwise_binary: currently only support 2d and 3d")
  TORCH_CHECK(
      !alpha.has_value() || alpha.value().to<float>() == 1.0,
      "mkldnn_convolution_pointwise_binary: the alpha value should be none or 1.0");

  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  // Make sure the inputs have the same type (device, layout, dtype), that the
  // device is CPU, and that the dtype is float, bfloat16 or half.
  check_mkldnn_binary_fusion_inputs(input_t, other_t, weight_t, bias);

  int64_t dim = input_t.ndimension() - 2;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", dim);
  check_shape_forward(
      input_t, weight_t, bias, padding_expanded, stride_expanded, dilation_expanded, groups);

  auto output_sizes = conv_output_size(
      input_t.sizes(), weight_t.sizes(), padding_expanded, stride_expanded, dilation_expanded);
  // TODO: support broadcast binary fusion.
  TORCH_CHECK(
      output_sizes == other_t.sizes(),
      "Binary Fusion's inputs should have same shape");
  // Only call the fusion path for the channels-last path.
  // TODO: OneDNN doesn't optimize well for the groups > 1 case; it will be
  // enabled in the next OneDNN release.
  bool use_channels_last =
      weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t);
  bool can_be_fused = groups == 1 && use_channels_last;

  c10::string_view unary_attr_value = "none";
  ideep::algorithm unary_alg;
  if (unary_attr.has_value()) {
    auto it_unary = fusion_unary_alg_map().find(unary_attr.value());
    // Now, we only support conv+binary+relu.
    TORCH_CHECK(
        it_unary != fusion_unary_alg_map().end(),
        "Unary Fusion behavior undefined.");
    unary_attr_value = unary_attr.value();
    unary_alg = it_unary->second;
  }
  auto it_binary = fusion_binary_alg_map().find(binary_attr);
  TORCH_CHECK(
      it_binary != fusion_binary_alg_map().end(),
      "Binary Fusion behavior undefined.");
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  if (can_be_fused) {
    auto memory_format =
        mkldnn_convolution_memory_format(input_t.ndimension(), true);
    auto input = input_t.contiguous(memory_format);
    auto weight =
        weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
    auto other = other_t.contiguous(memory_format);
    auto output = at::empty_like(other);
    const ideep::tensor x = itensor_from_tensor(input);
    const ideep::tensor w = itensor_from_tensor(weight);
    const ideep::tensor z = itensor_from_tensor(other);
    ideep::tensor y = itensor_from_tensor(output);
    auto output_size = other.sizes().vec();
    ideep::tag format_tag = ideep::tag::nhwc;
    if (input_t.ndimension() == 5) {
      format_tag = ideep::tag::ndhwc;
    }
    auto other_desc = ideep::tensor::desc(
        output_size, get_mkldnn_dtype(weight.scalar_type()), format_tag);

    ideep::attr_t op_attr;
    ideep::post_ops po;
    po.append_binary(it_binary->second, other_desc);
    if (unary_attr_value != "none") {
      po.append_eltwise(unary_alg, 0.f, 0.f);
    }
    op_attr.set_post_ops(po);

    if (bias.defined()) {
      const ideep::tensor b = itensor_from_tensor(bias);
      ideep::convolution_forward::compute_binary(
          x,
          z,
          w,
          b,
          output_size,
          y,
          stride_expanded,
          dilation_expanded,
          padding_expanded,
          padding_expanded,
          groups,
          /* is_channels_last */ true,
          op_attr);
    } else {
      ideep::convolution_forward::compute_binary(
          x,
          z,
          w,
          output_size,
          y,
          stride_expanded,
          dilation_expanded,
          padding_expanded,
          padding_expanded,
          groups,
          /* is_channels_last */ true,
          op_attr);
    }
    return output;
  } else {
    // Fallback case: if the inputs are not channels last or have different
    // dtypes, OneDNN fusion may cause a performance regression.
    Tensor output;
    if (weight_t.is_mkldnn()) {
      output = _mkldnn_convolution(
          input_t, weight_t, bias, padding_expanded, stride_expanded, dilation_expanded, groups, true);
    } else {
      output = at::convolution(
          input_t, weight_t, bias, stride_expanded, padding_expanded, dilation_expanded, false, 0, groups);
    }
    if (binary_attr == "add" && unary_attr_value != "none") {
      output = at::native::add_relu_(output, other_t);
      return output;
    }
    if (binary_attr == "add") {
      output.add_(other_t);
    } else if (binary_attr == "sub") {
      output.sub_(other_t);
    } else if (binary_attr == "mul") {
      output.mul_(other_t);
    } else {
      output.div_(other_t);
    }
    if (unary_attr_value != "none") {
      output.relu_();
    }
    return output;
  }
}
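
// Added commentary (illustrative, not part of the original source): for
// binary_attr == "add" and unary_attr == "relu" on 2d inputs, the fused op
// above is functionally equivalent (up to memory layout) to
//   auto out = at::relu(at::convolution(input_t, weight_t, bias, stride, padding,
//                                       dilation, /*transposed*/false,
//                                       /*output_padding*/{0, 0}, groups) + other_t);
// which is exactly what the non-fused fallback branch computes via add_relu_.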

// Fuse convolution + binary_op + unary_op for better performance, computing
// in place: other_t = unary_op(binary_op(conv(input_t, ...), other_t, alpha)).
// binary_attr selects the binary op, e.g. "add" or another binary operation.
// unary_attr selects the unary op, e.g. "relu"; if it is none, no unary post
// op is applied. unary_scalars and unary_algorithm are the parameters of the
// unary op, e.g. "hardtanh" takes scalar parameters and "gelu" takes an
// algorithm parameter.

Tensor& mkldnn_convolution_pointwise_binary_(
    Tensor& other_t,
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view binary_attr,
    std::optional<at::Scalar> alpha,
    std::optional<c10::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
    std::optional<c10::string_view> unary_algorithm) {
  // other_t += convolution(...), other_t = unary(other_t)
  TORCH_CHECK(
      input_t.ndimension() == 4 || input_t.ndimension() == 5,
      "mkldnn_convolution_add_: currently only support 2d and 3d")
  TORCH_CHECK(
      binary_attr == "add",
      "mkldnn_convolution_pointwise_binary_: only support binary op fusion")
  TORCH_CHECK(
      !alpha.has_value() || alpha.value().to<float>() == 1.0,
      "mkldnn_convolution_pointwise_binary: the alpha value for the binary op should be none(meaning 1.0) or 1.0");
  TORCH_CHECK(
      !unary_attr.has_value() || unary_attr.value() == "relu",
      "mkldnn_convolution_pointwise_binary: only support none or relu unary op fusion after binary op");

  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  // Make sure the inputs have the same type (device, layout, dtype), that the
  // device is CPU, and that the dtype is float, bfloat16 or half.
  check_mkldnn_binary_fusion_inputs(input_t, other_t, weight_t, bias);
  int64_t dim = input_t.ndimension() - 2;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", dim);
  check_shape_forward(
      input_t, weight_t, bias, padding_expanded, stride_expanded, dilation_expanded, groups);

  auto output_sizes = conv_output_size(
      input_t.sizes(), weight_t.sizes(), padding_expanded, stride_expanded, dilation_expanded);
  TORCH_CHECK(
      output_sizes == other_t.sizes(),
      "Add Fusion's inputs should have same shape");
  // Only call the fusion path when using channels last and the output is a contiguous (channels-last) tensor.
  bool can_be_fused = (weight_t.is_mkldnn() ||
                       mkldnn_conv_use_channels_last(input_t, weight_t)) &&
      (other_t.is_contiguous(at::MemoryFormat::ChannelsLast) ||
       other_t.is_contiguous(at::MemoryFormat::ChannelsLast3d));
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  if (can_be_fused) {
    ideep::tensor y = itensor_from_tensor(other_t);
    ideep::attr_t op_attr;
    if (unary_attr.has_value()) {
      op_attr = ideep::attr_t::residual();
    } else {
      op_attr = ideep::attr_t::fuse_sum();
    }
    _mkldnn_convolution_out(
        input_t,
        weight_t,
        bias,
        output_sizes,
        y,
        stride_expanded,
        dilation_expanded,
        padding_expanded,
        groups,
        true,
        op_attr);
  } else {
    // Fallback case: if the inputs are not channels last or have different
    // dtypes, OneDNN fusion may cause a performance regression.
    Tensor output;
    if (weight_t.is_mkldnn()) {
      output = _mkldnn_convolution(
          input_t, weight_t, bias, padding_expanded, stride_expanded, dilation_expanded, groups, true);
    } else {
      output = at::convolution(
          input_t, weight_t, bias, stride_expanded, padding_expanded, dilation_expanded, false, 0, groups);
    }
    if (unary_attr.has_value()) {
      other_t = at::native::add_relu_(other_t, output);
    } else {
      other_t.add_(output);
    }
  }
  return other_t;
}
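
// Added commentary (illustrative, not part of the original source): this
// in-place variant accumulates the convolution result into `other_t`, i.e. it
// behaves like
//   other_t.add_(conv(input_t, ...));   // binary_attr == "add"
//   other_t.relu_();                    // only when unary_attr == "relu"
// with the fused path writing directly into other_t's channels-last buffer.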

std::vector<int64_t> _original_deconv_weight_size(
    const Tensor& weight_t,
    int64_t groups) {
  TORCH_CHECK(weight_t.is_mkldnn() || weight_t.is_meta(), "expects weight_t to be mkldnn or meta tensor");
  // The size of weight_t is the prepacked size.
  //  Groups > 1: [g*o, i/g, ...]
  //  Groups == 1: [o, i, ...]
  // Returns original weight size in [i, o, ...]
  auto dim = weight_t.sizes().size();
  TORCH_CHECK(dim > 2);

  std::vector<int64_t> weight_IOHW_sizes(dim);
  if (groups > 1) {
    weight_IOHW_sizes[0] = weight_t.sizes()[1] * groups;
    weight_IOHW_sizes[1] = weight_t.sizes()[0] / groups;
  } else {
    weight_IOHW_sizes[0] = weight_t.sizes()[1];
    weight_IOHW_sizes[1] = weight_t.sizes()[0];
  }
  for (const auto d : c10::irange(2, dim)) {
    weight_IOHW_sizes[d] = weight_t.sizes()[d];
  }
  return weight_IOHW_sizes;
}
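
// A concrete example of the mapping above (added for illustration, not part
// of the original source): with groups == 2 and a prepacked deconv weight of
// size [16, 4, 3, 3] (i.e. [g*o, i/g, kH, kW] with o == 8, i == 8), the
// returned IOHW size is [4*2, 16/2, 3, 3] == [8, 8, 3, 3].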


Tensor _mkldnn_convolution_transpose(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool use_channels_last,
    c10::string_view attr = "none",
    torch::List<std::optional<at::Scalar>> scalars =
        torch::List<std::optional<at::Scalar>>(),
    std::optional<c10::string_view> algorithm = std::nullopt) {
  ideep::attr_t op_attr = ideep::attr_t();
  if (attr != "none") {
    auto it = fusion_unary_attr_map().find(attr);
    TORCH_CHECK(it != fusion_unary_attr_map().end(), "Fusion behavior undefined.");
    op_attr = it->second(scalars, algorithm);
  }

  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  mkldnn_check_low_precision(input_t.scalar_type(), "mkldnn_convolution_transpose");

  std::vector<int64_t> weight_IOHW_sizes = weight_t.is_mkldnn() ? _original_deconv_weight_size(weight_t, groups) : weight_t.sizes().vec();

  auto memory_format =
      mkldnn_convolution_memory_format(input_t.ndimension(), use_channels_last);

  auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
  auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);

  int64_t dim = input.ndimension() - 2;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", dim);
  const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", dim);
  auto output_sizes = conv_input_size(input.sizes(), weight_IOHW_sizes, padding_expanded, output_padding_expanded, stride_expanded, dilation_expanded, groups);
  auto output = at::empty({0}, input.options());

  const ideep::tensor x = itensor_from_tensor(input, /*from_const_data_ptr*/true);

  ideep::tensor w = itensor_from_tensor(weight, /*from_const_data_ptr*/true);
  if (!weight.is_mkldnn()) {
    // mkldnn transposed convolution has the weight in logical order OIHW or OIDHW,
    // while PyTorch uses IOHW or IODHW; `transpose_(0, 1)` swaps strides (no memory copy).
    w.transpose_(0, 1);
  }

  ideep::tensor y;
  if (use_channels_last) {
    output.resize_(output_sizes, memory_format);
    y = itensor_from_tensor(output);
  }

  if (bias.defined()) {
    const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true);
    ideep::convolution_transpose_forward::compute_v3(
        x,
        w,
        b,
        output_sizes,
        y,
        stride_expanded,
        padding_expanded,
        padding_r(padding_expanded, output_padding_expanded),
        dilation.vec(),
        groups,
        use_channels_last,
        op_attr);
  } else {
    ideep::convolution_transpose_forward::compute_v3(
        x,
        w,
        output_sizes,
        y,
        stride_expanded,
        padding_expanded,
        padding_r(padding_expanded, output_padding_expanded),
        dilation.vec(),
        groups,
        use_channels_last,
        op_attr);
  }
  if (input.is_mkldnn()) {
    return MKLDNNTensor(y, input.options());
  } else if (!use_channels_last) {
    return mkldnn_to_dense(MKLDNNTensor(y, input.options()));
  } else {
    return output;
  }
}

Tensor mkldnn_convolution_transpose_pointwise(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm) {
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  bool use_channels_last =
      weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t);
  return _mkldnn_convolution_transpose(
      input_t,
      weight_t,
      bias_opt,
      padding,
      output_padding,
      stride,
      dilation,
      groups,
      use_channels_last,
      attr,
      scalars,
      algorithm
  );
}

Tensor mkldnn_convolution_transpose_pointwise_meta(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm) {

  std::vector<int64_t> weight_IOHW_sizes = _original_deconv_weight_size(weight_t, groups);
  int64_t dim = input_t.ndimension() - 2;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", dim);
  const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", dim);
  auto output_sizes = conv_input_size(input_t.sizes(), weight_IOHW_sizes, padding_expanded, output_padding_expanded, stride_expanded, dilation_expanded, groups);

  auto output = at::empty(output_sizes, input_t.options());
  return output;
}

Tensor mkldnn_convolution_backward_input(
    IntArrayRef input_size,
    const Tensor& grad_output,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool bias_defined,
    bool is_channels_last) {
  auto grad_input = at::empty({0}, grad_output.options());

  auto grad_y = itensor_from_tensor(grad_output, /*from_const_data_ptr*/true);
  auto w = itensor_view_from_dense(weight, /*from_const_data_ptr*/true);

  ideep::tensor grad_x;
  if (is_channels_last) {
    auto memory_format = mkldnn_convolution_memory_format(grad_output.ndimension(), is_channels_last);
    grad_input.resize_(input_size, memory_format);
    grad_x = itensor_from_tensor(grad_input);
  }
  ideep::convolution_backward_data::compute_v2(
      grad_y,
      w,
      input_size.vec(),
      grad_x,
      stride.vec(),
      dilation.vec(),
      padding.vec(),
      padding.vec(),
      groups,
      is_channels_last);

  if (grad_output.is_mkldnn()) {
    return MKLDNNTensor(grad_x, grad_output.options());
  } else if (!is_channels_last) {
    return mkldnn_to_dense(MKLDNNTensor(grad_x, grad_output.options()));
  } else {
    return grad_input;
  }
}

std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
    IntArrayRef weight_size,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool bias_defined,
    bool is_channels_last) {
  const ideep::tensor grad_y = itensor_from_tensor(grad_output, /*from_const_data_ptr*/true);
  const ideep::tensor x = itensor_from_tensor(input, /*from_const_data_ptr*/true);

  ideep::tensor grad_w, grad_b;
  if (bias_defined) {
    ideep::convolution_backward_weights::compute_v2(
        x,
        grad_y,
        weight_size.vec(),
        grad_w,
        grad_b,
        stride.vec(),
        dilation.vec(),
        padding.vec(),
        padding.vec(),
        groups,
        is_channels_last);
  } else {
    ideep::convolution_backward_weights::compute_v2(
        x,
        grad_y,
        weight_size.vec(),
        grad_w,
        stride.vec(),
        dilation.vec(),
        padding.vec(),
        padding.vec(),
        groups,
        is_channels_last);
  }

  if (!is_channels_last) {
    return std::make_tuple(
        mkldnn_to_dense(MKLDNNTensor(grad_w, grad_output.options())),
        bias_defined ? mkldnn_to_dense(MKLDNNTensor(grad_b, grad_output.options())) : Tensor());
  } else {
    auto memory_format = mkldnn_convolution_memory_format(grad_output.ndimension(), is_channels_last);
    return std::make_tuple(
        mkldnn_to_dense(MKLDNNTensor(grad_w, grad_output.options())).to(memory_format),
        bias_defined ? mkldnn_to_dense(MKLDNNTensor(grad_b, grad_output.options())) : Tensor());
  }
}

std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
    const Tensor& input_t, const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask)
{
  bool is_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t);
  auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last);
  Tensor grad_output = grad_output_t.is_mkldnn() ? grad_output_t : grad_output_t.contiguous(memory_format);

  Tensor input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
  Tensor weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
  int64_t dim = input.ndimension() - 2;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", dim);
  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = mkldnn_convolution_backward_input(
      input.sizes(), grad_output, weight, padding_expanded, stride_expanded, dilation_expanded, groups, output_mask[2], is_channels_last);
  }
  if (output_mask[1] || output_mask[2]) {
    std::tie(grad_weight, grad_bias) = mkldnn_convolution_backward_weights(
      weight.sizes(), grad_output, input, padding_expanded, stride_expanded, dilation_expanded, groups, output_mask[2], is_channels_last);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
}

REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_backward_stub, &mkldnn_convolution_backward);

namespace{
Tensor mkldnn_convolution_transpose(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups)
{
  bool use_channels_last = mkldnn_conv_use_channels_last(input, weight);
  return _mkldnn_convolution_transpose(
      input,
      weight,
      bias_opt,
      padding,
      output_padding,
      stride,
      dilation,
      groups,
      use_channels_last
  );
}

Tensor mkldnn_convolution_transpose_backward_input(
    IntArrayRef input_size,
    const Tensor& grad_output,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool bias_defined,
    bool is_channels_last) {
  auto grad_input = at::empty({0}, grad_output.options());

  auto grad_y = itensor_from_tensor(grad_output, /*from_const_data_ptr*/true);
  auto w = itensor_view_from_dense(weight, /*from_const_data_ptr*/true).transpose_(0, 1);

  ideep::tensor grad_x;
  if (is_channels_last) {
    auto memory_format = mkldnn_convolution_memory_format(grad_output.ndimension(), is_channels_last);
    grad_input.resize_(input_size, memory_format);
    grad_x = itensor_from_tensor(grad_input);
  }
  ideep::convolution_transpose_backward_data::compute_v3(
      grad_y,
      w,
      input_size.vec(),
      grad_x,
      stride.vec(),
      padding.vec(),
      padding_r(padding, output_padding),
      dilation.vec(),
      groups,
      is_channels_last);

  if (grad_output.is_mkldnn()) {
    return MKLDNNTensor(grad_x, grad_output.options());
  } else if (!is_channels_last) {
    return mkldnn_to_dense(MKLDNNTensor(grad_x, grad_output.options()));
  } else {
    return grad_input;
  }
}

std::tuple<Tensor,Tensor> mkldnn_convolution_transpose_backward_weights(
    IntArrayRef weight_size,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool bias_defined,
    bool is_channels_last) {
  auto grad_y = itensor_from_tensor(grad_output, /*from_const_data_ptr*/true);
  auto x = itensor_from_tensor(input, /*from_const_data_ptr*/true);

  ideep::tensor grad_w, grad_b;
  if (bias_defined) {
    ideep::convolution_transpose_backward_weights::compute_v3(
        x,
        grad_y,
        weight_size.vec(),
        grad_w,
        grad_b,
        stride.vec(),
        padding.vec(),
        padding_r(padding, output_padding),
        dilation.vec(),
        groups,
        is_channels_last);
  } else {
    ideep::convolution_transpose_backward_weights::compute_v3(
        x,
        grad_y,
        weight_size.vec(),
        grad_w,
        stride.vec(),
        padding.vec(),
        padding_r(padding, output_padding),
        dilation.vec(),
        groups,
        is_channels_last);
  }

  if (!is_channels_last) {
    return std::make_tuple(
        mkldnn_to_dense(MKLDNNTensor(grad_w, grad_output.options())),
        bias_defined ? mkldnn_to_dense(MKLDNNTensor(grad_b, grad_output.options())) : Tensor());
  } else {
    auto memory_format = mkldnn_convolution_memory_format(grad_output.ndimension(), is_channels_last);
    return std::make_tuple(
        mkldnn_to_dense(MKLDNNTensor(grad_w, grad_output.options())).to(memory_format),
        bias_defined ? mkldnn_to_dense(MKLDNNTensor(grad_b, grad_output.options())) : Tensor());
  }
}

std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
    const Tensor& input_t, const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    std::array<bool,3> output_mask)
{
  bool is_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t);
  auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last);
  Tensor grad_output = grad_output_t.is_mkldnn() ? grad_output_t : grad_output_t.contiguous(memory_format);
  auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
  auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
  int64_t dim = input.ndimension() - 2;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", dim);
  const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", dim);

  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = mkldnn_convolution_transpose_backward_input(
        input.sizes(), grad_output, weight, padding_expanded, output_padding_expanded, stride_expanded, dilation_expanded, groups, output_mask[2], is_channels_last);
  }
  if (output_mask[1] || output_mask[2]) {
    std::tie(grad_weight, grad_bias) = mkldnn_convolution_transpose_backward_weights(
        weight.sizes(), grad_output, input, padding_expanded, output_padding_expanded, stride_expanded, dilation_expanded, groups, output_mask[2], is_channels_last);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
}

REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_stub, &mkldnn_convolution_transpose);
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub, &mkldnn_convolution_transpose_backward);

TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise"),
      TORCH_FN(mkldnn_convolution_pointwise));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise.binary"),
      TORCH_FN(mkldnn_convolution_pointwise_binary));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"),
      TORCH_FN(mkldnn_convolution_pointwise_binary_));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_transpose_pointwise"),
      TORCH_FN(mkldnn_convolution_transpose_pointwise));
}
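
// Illustrative call from Python (a sketch, not part of the original source;
// the exact operator schema lives in the mkldnn library registration, so
// treat the argument list below as an assumption):
//
//   torch.ops.mkldnn._convolution_pointwise(
//       x, w, b, padding, stride, dilation, groups, "relu", [], None)
//
// dispatches to mkldnn_convolution_pointwise above for CPU (dense) inputs and
// to the MkldnnCPU registration below for MKLDNN (blocked) inputs.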

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise"),
      TORCH_FN(mkldnn_convolution_pointwise));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise.binary"),
      TORCH_FN(mkldnn_convolution_pointwise_binary));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"),
      TORCH_FN(mkldnn_convolution_pointwise_binary_));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_transpose_pointwise"),
      TORCH_FN(mkldnn_convolution_transpose_pointwise));
}

TORCH_LIBRARY_IMPL(mkldnn, Meta, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_transpose_pointwise"),
      TORCH_FN(mkldnn_convolution_transpose_pointwise_meta));
}
}}  // namespace at::native

#endif