#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <utility>

namespace at::native {

using conv_depthwise2d_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub);
using conv_depthwise3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub);
using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
using mps_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub);
using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);
using miopen_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub);
using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub);
using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub);
using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const std::optional<Tensor>&,
    IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub);
using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub);
using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub);
using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub);
using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub);
using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub);

namespace {
  bool is_cudnnv8_heuristic_mode_b() {
    static const bool is_cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
    return is_cudnnv8_heuristic_mode_b;
  }
}

inline bool cudnnv8_enabled_check_debug() {
  static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
  static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
  static uint8_t cudnnv8_debugcount = 0;
  if (cudnnv8_debug && cudnnv8_debugcount < 10) {
    TORCH_WARN("TORCH_CUDNN_V8_API_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE_B: ", is_cudnnv8_heuristic_mode_b());
    cudnnv8_debugcount++;
  }
  return cudnnv8_flag;
}

inline bool cudnnv8_use_heur_mode_b() {
  return is_cudnnv8_heuristic_mode_b();
}
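
// The helpers above are driven purely by environment variables; for
// illustration (the shell invocation is hypothetical):
//   TORCH_CUDNN_V8_API_DEBUG=1 TORCH_CUDNN_USE_HEURISTIC_MODE_B=1 ./my_program
// makes cudnnv8_enabled_check_debug() emit up to ten debug warnings and
// cudnnv8_use_heur_mode_b() return true, while TORCH_CUDNN_V8_API_DISABLED=1
// makes cudnnv8_enabled_check_debug() return false.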

// Keep in sync with py::enum_ in Module.cpp
enum class ConvBackend {
  CudaDepthwise2d,
  CudaDepthwise3d,
  Cudnn,
  CudnnTranspose,
  Empty,
  Miopen,
  MiopenDepthwise,
  MiopenTranspose,
  Mkldnn,
  MkldnnTranspose,
  MkldnnEmpty,
  NnpackSpatial,
  Overrideable,
  Slow2d,
  Slow3d,
  SlowDilated2d,
  SlowDilated3d,
  SlowTranspose2d,
  SlowTranspose3d,
  Winograd3x3Depthwise,
  Xnnpack2d,
  Mps,
  MpsTranspose,
};

// Overload for selecting the convolution backend from the full set of convolution inputs.
// This overload is exposed to python for testing, etc.
TORCH_API ConvBackend select_conv_backend(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
    bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);

TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
    const Tensor& weight,
    const ConvBackend backend);
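
// A minimal usage sketch (the shapes and variable names are hypothetical):
// query which backend would run a plain 2-d convolution, then ask which
// memory format that backend prefers.
//
//   at::Tensor input  = at::randn({1, 3, 32, 32});
//   at::Tensor weight = at::randn({16, 3, 3, 3});
//   ConvBackend backend = at::native::select_conv_backend(
//       input, weight, /*bias_opt=*/std::nullopt,
//       /*stride=*/{1, 1}, /*padding=*/{1, 1}, /*dilation=*/{1, 1},
//       /*transposed=*/false, /*output_padding=*/{0, 0}, /*groups=*/1,
//       /*bias_sizes_opt=*/std::nullopt);
//   at::MemoryFormat fmt = at::native::_determine_backend_memory_format(input, weight, backend);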

// ---------------------------------------------------------------------
//
// Math
//
// ---------------------------------------------------------------------

constexpr int input_batch_size_dim = 0;  // also grad_input
constexpr int input_channels_dim = 1;
constexpr int output_batch_size_dim = 0;  // also grad_output
constexpr int output_channels_dim = 1;
constexpr int weight_output_channels_dim = 0;
constexpr int weight_input_channels_dim = 1;

// Often written as 2 + max_dim (extra dims for batch size and channels)
constexpr int max_dim = 3;
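
// Example: an NCHW (or NCDHW) input/output tensor indexes its batch dimension
// with *_batch_size_dim (0) and its channel dimension with *_channels_dim (1),
// while a weight tensor of shape (out_channels, in_channels / groups, kH, kW)
// uses weight_output_channels_dim (0) and weight_input_channels_dim (1); a 3-d
// (volumetric) convolution therefore works with 2 + max_dim == 5 tensor dims.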

// ---------------------------------------------------------------------
//
// Checking
//
// ---------------------------------------------------------------------

// Used on pad, stride and dilation
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
{
  TORCH_CHECK(args.size() <= expected_size,
           "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
           expected_size, " (while checking arguments for ", c, ")");
  TORCH_CHECK(args.size() >= expected_size,
           "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
           expected_size, " (while checking arguments for ", c, ")");

  auto num_negative_values = std::count_if(args.begin(), args.end(), [](int64_t x){return x < 0;});
  if (num_negative_values > 0){
    std::stringstream ss;
    ss << arg_name << " should be non-negative but got (";
    std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int64_t>(ss, ", "));
    ss << args.back() << ")" << " (while checking arguments for " << c << ")";
    AT_ERROR(ss.str());
  }
}
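
// Illustrative calls (the values are hypothetical): for a 2-d convolution,
//   check_args(c, /*args=*/{1, 1}, /*expected_size=*/2, "padding")
// passes, while supplying three values trips the "Too many padding values"
// check and supplying a negative value trips the non-negativity check.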


// NOTE [ Convolution checks ]
//
// NB: For many call sites, it is not strictly necessary to check all of
// these relationships (for example, for forward convolution, we compute
// the size of output ourselves, so we don't actually need to check
// output). However, writing a single function that does everything
// means we get to reuse it for both forwards and all backwards
// variants, even when the set of "real" inputs varies.  The magic of
// relational computing!
//
// (There is one downside, which is that it is slightly harder to write
// error messages which are able to distinguish between real inputs
// (which the user can change) and computed inputs (which the user can
// only indirectly affect).  It would be an interesting exercise to
// come up with a general framework to handle such situations.)
inline void convolution_shape_check(
    CheckedFrom c,
    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
  check_args(c, padding, input->dim() - 2, "padding");
  check_args(c, stride, padding.size(), "stride");
  check_args(c, dilation, padding.size(), "dilation");

  // Input
  checkDimRange(c, input, 3, 6 /* exclusive */);
  checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);

  // Weight
  checkSameDim(c, input, weight);

  // TODO: check that output->size() matches output_sizes
  // TODO: check that weight matches output->sizes()
  checkSameDim(c, input, output);
}

// NB: conv_output_size and conv_input_size are not bijections,
// as conv_output_size loses information; this is why conv_input_size
// takes an extra output_padding argument to resolve the ambiguity.
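//
// A 1-d worked example of the ambiguity (the numbers are illustrative): with
// kernel 3, stride 2, padding 0 and dilation 1, input extents of both 7 and 8
// give an output extent of (L + 2*0 - 3) / 2 + 1 == 3; reconstructing the
// input as (3 - 1) * 2 - 2*0 + 3 + output_padding then yields 7 for
// output_padding == 0 and 8 for output_padding == 1.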

template <typename T>
inline std::vector<T> _conv_output_size(
    ArrayRef<T> input_size, ArrayRef<T> weight_size,
    ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
) {
  // ASSERT(input_size.size() > 2)
  // ASSERT(input_size.size() == weight_size.size())
  bool has_dilation = !dilation.empty();
  auto dim = input_size.size();
  std::vector<T> output_size(dim);
  output_size[0] = input_size[input_batch_size_dim];
  output_size[1] = weight_size[weight_output_channels_dim];
  for (const auto d : c10::irange(2, dim)) {
    auto dilation_ = has_dilation ? dilation[d - 2] : 1;
    auto kernel = dilation_ * (weight_size[d] - 1) + 1;
    output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
  }
  return output_size;
}

inline std::vector<int64_t> conv_output_size(
    IntArrayRef input_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}

inline std::vector<c10::SymInt> conv_output_size(
    SymIntArrayRef input_size, SymIntArrayRef weight_size,
    SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
) {
  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}
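
// Usage sketch (the shapes are hypothetical): a 3x3 convolution with stride 1
// and padding 1 preserves the spatial extent, e.g.
//   conv_output_size(/*input_size=*/{1, 3, 32, 32}, /*weight_size=*/{16, 3, 3, 3},
//                    /*padding=*/{1, 1}, /*stride=*/{1, 1})
// returns {1, 16, 32, 32}.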

template <typename T>
std::vector<T> _conv_input_size(
    ArrayRef<T> output_size, ArrayRef<T> weight_size,
    ArrayRef<T> padding, ArrayRef<T> output_padding, ArrayRef<T> stride, ArrayRef<T> dilation, T groups
) {
  // ASSERT(output_size.size() > 2)
  // ASSERT(output_size.size() == weight_size.size())
  auto dim = output_size.size();
  std::vector<T> input_size(dim);
  input_size[0] = output_size[output_batch_size_dim];
  input_size[1] = weight_size[weight_input_channels_dim] * groups;
  for (const auto d : c10::irange(2, dim)) {
    auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
                     kernel + output_padding[d - 2];
  }
  return input_size;
}

inline std::vector<c10::SymInt> conv_input_size(
    SymIntArrayRef output_size, SymIntArrayRef weight_size,
    SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
) {
  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, std::move(groups));
}

inline std::vector<int64_t> conv_input_size(
    IntArrayRef output_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
}
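
// Inverting the example above (the shapes are hypothetical):
//   conv_input_size(/*output_size=*/{1, 16, 32, 32}, /*weight_size=*/{16, 3, 3, 3},
//                   /*padding=*/{1, 1}, /*output_padding=*/{0, 0}, /*stride=*/{1, 1},
//                   /*dilation=*/{1, 1}, /*groups=*/1)
// recovers the input size {1, 3, 32, 32}; the channel count comes from
// weight_size[1] * groups.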

template <typename T>
std::vector<T> _conv_weight_size(
    ArrayRef<T> input_size, ArrayRef<T> output_size,
    ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  auto dim = input_size.size();
  std::vector<T> weight_size(dim);
  weight_size[0] = output_size[1];
  weight_size[1] = input_size[1] / groups;
  for (const auto d : c10::irange(2, dim)) {
    auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
               + padding[d - 2] * 2 - output_padding[d - 2];
    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
  }
  return weight_size;
}

inline std::vector<c10::SymInt> conv_weight_size(
    SymIntArrayRef input_size, SymIntArrayRef output_size,
    SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}

inline std::vector<int64_t> conv_weight_size(
    IntArrayRef input_size, IntArrayRef output_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}
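
// Completing the same example (the shapes are hypothetical):
//   conv_weight_size(/*input_size=*/{1, 3, 32, 32}, /*output_size=*/{1, 16, 32, 32},
//                    /*padding=*/{1, 1}, /*output_padding=*/{0, 0}, /*stride=*/{1, 1},
//                    /*dilation=*/{1, 1}, /*groups=*/1)
// yields the weight size {16, 3, 3, 3}.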

inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
  std::vector<int64_t> shape(dim, 1);
  shape[1] = -1;
  return bias.reshape(shape);
}
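
// Example: for a 4-d (NCHW) convolution, reshape_bias(4, bias) reshapes a bias
// of shape {C} to {1, C, 1, 1} so it broadcasts over the batch and spatial dims.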

inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
  // disable NHWC for float64 input.
  if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return at::MemoryFormat::Contiguous;
  }
  long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();
  auto weight_ndim = weight.ndimension();

  bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
    (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
    (weight_memory_format == at::MemoryFormat::ChannelsLast)
  );
  if (can_use_cudnn_channels_last_2d) {
    return at::MemoryFormat::ChannelsLast;
  }

  bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
    (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
    (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
  );
  if (can_use_cudnn_channels_last_3d) {
    return at::MemoryFormat::ChannelsLast3d;
  }

  return at::MemoryFormat::Contiguous;
}
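
// Typical usage sketch (the variable names are hypothetical): pick the memory
// format once, then make both operands contiguous in it before dispatching:
//
//   auto memory_format = cudnn_conv_suggest_memory_format(input, weight);
//   auto input_c  = input.contiguous(memory_format);
//   auto weight_c = weight.contiguous(memory_format);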

// controls whether emptyCache will be called following cudnn conv benchmarking
TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();


inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {

  // disable NHWC for float64 input.
  if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  bool can_use_miopen_channels_last_2d = false;
  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
  // See #64427
  static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
            ( (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
            (weight_memory_format == at::MemoryFormat::ChannelsLast) )
        );

  bool can_use_miopen_channels_last_3d = false;

  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
}
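
// For illustration, the helper above only suggests channels-last when
// PYTORCH_MIOPEN_SUGGEST_NHWC is set to a truthy value in the environment,
// e.g. (the command is hypothetical) PYTORCH_MIOPEN_SUGGEST_NHWC=1 ./my_program.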

inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {

  // disable NHWC for float64 input.
  if (input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  // disable NHWC for MkldnnCPU tensor.
  if (input.is_mkldnn() || weight.is_mkldnn()) {
    return false;
  }

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  bool can_use_mkldnn_channels_last_2d =
      (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast);

  bool can_use_mkldnn_channels_last_3d =
      (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast3d);

  return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
}

inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
      (input_memory_format  == at::MemoryFormat::ChannelsLast) || (
       weight_memory_format == at::MemoryFormat::ChannelsLast));

  return can_use_thnn_channels_last_2d;
}

inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {

  // check layout only for xpu tensor.
  if (!input.is_xpu() || !weight.is_xpu()) {
    return false;
  }

  // disable NHWC for float64 input.
  if (input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  bool can_use_xpu_channels_last_2d =
      (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast);

  bool can_use_xpu_channels_last_3d =
      (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast3d);

  return can_use_xpu_channels_last_2d || can_use_xpu_channels_last_3d;
}

} // namespace at::native