#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <utility>

namespace at::native {

using conv_depthwise2d_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub);
using conv_depthwise3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub);
using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
using mps_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub);
using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);
using miopen_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub);
using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub);
using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub);
using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const std::optional<Tensor>&,
    IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub);
using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub);
using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub);
using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub);
using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub);
using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub);

namespace {
bool is_cudnnv8_heuristic_mode_b() {
  static const bool is_cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
  return is_cudnnv8_heuristic_mode_b;
}
}

inline bool cudnnv8_enabled_check_debug() {
  static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
  static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
  static uint8_t cudnnv8_debugcount = 0;
  if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) {
    TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", is_cudnnv8_heuristic_mode_b());
    cudnnv8_debugcount++;
  }
  return cudnnv8_flag == 1;
}

inline bool cudnnv8_use_heur_mode_b() {
  return is_cudnnv8_heuristic_mode_b();
}
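
// A quick illustration (hypothetical shell invocations; "train.py" is only a
// placeholder) of how the environment variables checked above interact:
//
//   TORCH_CUDNN_V8_API_DISABLED=1 python train.py
//     -> cudnnv8_enabled_check_debug() returns false.
//   TORCH_CUDNN_V8_API_DEBUG=1 python train.py
//     -> the first 10 calls to cudnnv8_enabled_check_debug() emit a
//        TORCH_WARN reporting the v8 and heuristic-mode-B flags.
//   TORCH_CUDNN_USE_HEURISTIC_MODE_B=1 python train.py
//     -> cudnnv8_use_heur_mode_b() returns true.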

// Keep in sync with py::enum_ in Module.cpp
enum class ConvBackend {
  CudaDepthwise2d,
  CudaDepthwise3d,
  Cudnn,
  CudnnTranspose,
  Empty,
  Miopen,
  MiopenDepthwise,
  MiopenTranspose,
  Mkldnn,
  MkldnnTranspose,
  MkldnnEmpty,
  NnpackSpatial,
  Overrideable,
  Slow2d,
  Slow3d,
  SlowDilated2d,
  SlowDilated3d,
  SlowTranspose2d,
  SlowTranspose3d,
  Winograd3x3Depthwise,
  Xnnpack2d,
  Mps,
  MpsTranspose,
};

// Overload for selecting the convolution backend from the full set of convolution inputs.
// This overload is exposed to python for testing, etc.
TORCH_API ConvBackend select_conv_backend(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
    bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);

TORCH_API at::MemoryFormat _determine_backend_memory_format(
    const Tensor& input,
    const Tensor& weight,
    const ConvBackend backend);

// ---------------------------------------------------------------------
//
// Math
//
// ---------------------------------------------------------------------

constexpr int input_batch_size_dim = 0;  // also grad_input
constexpr int input_channels_dim = 1;
constexpr int output_batch_size_dim = 0;  // also grad_output
constexpr int output_channels_dim = 1;
constexpr int weight_output_channels_dim = 0;
constexpr int weight_input_channels_dim = 1;

// Often written as 2 + max_dim (extra dims for batch size and channels)
constexpr int max_dim = 3;
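// For example, a 3-d (volumetric) convolution operates on 5-d tensors laid
// out as (batch, channels, depth, height, width), i.e. 2 + 3 dimensions.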

// ---------------------------------------------------------------------
//
// Checking
//
// ---------------------------------------------------------------------

// Used on pad, stride and dilation
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
{
  TORCH_CHECK(args.size() <= expected_size,
              "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
              expected_size, " (while checking arguments for ", c, ")");
  TORCH_CHECK(args.size() >= expected_size,
              "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
              expected_size, " (while checking arguments for ", c, ")");

  auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;});
  if (num_negative_values > 0){
    std::stringstream ss;
    ss << arg_name << " should be non-negative but got (";
    std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
    ss << args.back() << ")" << " (while checking arguments for " << c << ")";
    AT_ERROR(ss.str());
  }
}
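
// For example (values are illustrative only), check_args(c, /*padding=*/{1, 1, 1},
// /*expected_size=*/2, "padding") trips the first TORCH_CHECK with
// "Too many padding values (3) supplied, expecting 2 ...".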


// NOTE [ Convolution checks ]
//
// NB: For many call sites, it is not strictly necessary to check all of
// these relationships (for example, for forward convolution, we compute
// the size of output ourselves, so we don't actually need to check
// output). However, writing a single function that does everything
// means we get to reuse it for both forwards and all backwards
// variants, even when the set of "real" inputs varies. The magic of
// relational computing!
//
// (There is one downside, which is that it is slightly harder to write
// error messages which are able to distinguish between real inputs
// (which the user can change) and computed inputs (which the user can
// only indirectly affect). It would be an interesting exercise to
// come up with a general framework to handle such situations.)
inline void convolution_shape_check(
    CheckedFrom c,
    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
  check_args(c, padding, input->dim() - 2, "padding");
  check_args(c, stride, padding.size(), "stride");
  check_args(c, dilation, padding.size(), "dilation");

  // Input
  checkDimRange(c, input, 3, 6 /* exclusive */);
  checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);

  // Weight
  checkSameDim(c, input, weight);

  // TODO: check that output->size() matches output_sizes
  // TODO: check that weight matches output->sizes()
  checkSameDim(c, input, output);
}

// NB: conv_output_size and conv_input_size are not bijections,
// as conv_output_size loses information; this is why conv_input_size
// takes an extra output_padding argument to resolve the ambiguity.
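//
// A minimal worked example of that ambiguity (numbers chosen purely for
// illustration): with kernel 3, stride 2, padding 0 and no dilation,
// conv_output_size maps a spatial size of 7 to (7 - 3) / 2 + 1 = 3 and a
// size of 8 to (8 - 3) / 2 + 1 = 3 as well. conv_input_size then returns
// (3 - 1) * 2 + 3 + output_padding, so output_padding = 0 recovers 7 and
// output_padding = 1 recovers 8.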

template <typename T>
inline std::vector<T> _conv_output_size(
    ArrayRef<T> input_size, ArrayRef<T> weight_size,
    ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
) {
  // ASSERT(input_size.size() > 2)
  // ASSERT(input_size.size() == weight_size.size())
  bool has_dilation = !dilation.empty();
  auto dim = input_size.size();
  std::vector<T> output_size(dim);
  output_size[0] = input_size[input_batch_size_dim];
  output_size[1] = weight_size[weight_output_channels_dim];
  for (const auto d : c10::irange(2, dim)) {
    auto dilation_ = has_dilation ? dilation[d - 2] : 1;
    auto kernel = dilation_ * (weight_size[d] - 1) + 1;
    output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
  }
  return output_size;
}

inline std::vector<int64_t> conv_output_size(
    IntArrayRef input_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}

inline std::vector<c10::SymInt> conv_output_size(
    SymIntArrayRef input_size, SymIntArrayRef weight_size,
    SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
) {
  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}
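
// Example (shapes are illustrative only): an NCHW input of size {1, 3, 32, 32}
// convolved with a weight of size {16, 3, 3, 3}, padding {1, 1}, stride {1, 1}
// and default (empty) dilation gives
//   conv_output_size({1, 3, 32, 32}, {16, 3, 3, 3}, {1, 1}, {1, 1})
//     == {1, 16, 32, 32}
// since each spatial dim is (32 + 2*1 - 3) / 1 + 1 = 32.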

template <typename T>
std::vector<T> _conv_input_size(
    ArrayRef<T> output_size, ArrayRef<T> weight_size,
    ArrayRef<T> padding, ArrayRef<T> output_padding, ArrayRef<T> stride, ArrayRef<T> dilation, T groups
) {
  // ASSERT(output_size.size() > 2)
  // ASSERT(output_size.size() == weight_size.size())
  auto dim = output_size.size();
  std::vector<T> input_size(dim);
  input_size[0] = output_size[output_batch_size_dim];
  input_size[1] = weight_size[weight_input_channels_dim] * groups;
  for (const auto d : c10::irange(2, dim)) {
    auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
                    kernel + output_padding[d - 2];
  }
  return input_size;
}

inline std::vector<c10::SymInt> conv_input_size(
    SymIntArrayRef output_size, SymIntArrayRef weight_size,
    SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
) {
  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, std::move(groups));
}

inline std::vector<int64_t> conv_input_size(
    IntArrayRef output_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
}
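
// Example (illustrative, e.g. for a transposed convolution): inverting the
// conv_output_size example above,
//   conv_input_size({1, 16, 32, 32}, {16, 3, 3, 3}, {1, 1}, {0, 0},
//                   {1, 1}, {1, 1}, 1)
//     == {1, 3, 32, 32}
// since each spatial dim is (32 - 1) * 1 - 2*1 + 3 + 0 = 32 and the channel
// dim is weight_size[1] * groups = 3 * 1.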

template <typename T>
std::vector<T> _conv_weight_size(
    ArrayRef<T> input_size, ArrayRef<T> output_size,
    ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  auto dim = input_size.size();
  std::vector<T> weight_size(dim);
  weight_size[0] = output_size[1];
  weight_size[1] = input_size[1] / groups;
  for (const auto d : c10::irange(2, dim)) {
    auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
                + padding[d - 2] * 2 - output_padding[d - 2];
    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
  }
  return weight_size;
}

inline std::vector<c10::SymInt> conv_weight_size(
    SymIntArrayRef input_size, SymIntArrayRef output_size,
    SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}

inline std::vector<int64_t> conv_weight_size(
    IntArrayRef input_size, IntArrayRef output_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}
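
// Example (illustrative, with the same sizes as above):
//   conv_weight_size({1, 3, 32, 32}, {1, 16, 32, 32}, {1, 1}, {0, 0},
//                    {1, 1}, {1, 1}, 1)
//     == {16, 3, 3, 3}
// since the per-dim kernel extent is 32 - (32 - 1) * 1 + 2*1 - 0 = 3 and
// (3 - 1) / 1 + 1 = 3, with weight_size[0] = output channels (16) and
// weight_size[1] = input channels / groups (3).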

inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
  std::vector<int64_t> shape(dim, 1);
  shape[1] = -1;
  return bias.reshape(shape);
}
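
// For example, reshape_bias(4, bias) reshapes a 1-d bias of size {C} to
// {1, -1, 1, 1} (i.e. {1, C, 1, 1}), so it broadcasts over an NCHW activation.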

inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
  // disable NHWC for float64 input.
  if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return at::MemoryFormat::Contiguous;
  }
  long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();
  auto weight_ndim = weight.ndimension();

  bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
      (input_memory_format == at::MemoryFormat::ChannelsLast) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast)
  );
  if (can_use_cudnn_channels_last_2d) {
    return at::MemoryFormat::ChannelsLast;
  }

  bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
      (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
  );
  if (can_use_cudnn_channels_last_3d) {
    return at::MemoryFormat::ChannelsLast3d;
  }

  return at::MemoryFormat::Contiguous;
}
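
// For example (assuming a CUDA build with cuDNN >= 7603): a 2-d convolution
// (4-d weight) where either the input or the weight already suggests
// ChannelsLast gets at::MemoryFormat::ChannelsLast, while any double input or
// weight always falls back to Contiguous.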

// controls whether emptyCache will be called following cudnn conv benchmarking
TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();


inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {

  // disable NHWC for float64 input.
  if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  bool can_use_miopen_channels_last_2d = false;
  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
  // See #64427
  static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
      ( (input_memory_format == at::MemoryFormat::ChannelsLast) ||
        (weight_memory_format == at::MemoryFormat::ChannelsLast) )
  );

  bool can_use_miopen_channels_last_3d = false;

  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
}
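
// For example, on a MIOpen (ROCm) build this returns true only when the
// process was started with PYTORCH_MIOPEN_SUGGEST_NHWC=1 and either the input
// or the weight already suggests a ChannelsLast layout; ChannelsLast3d support
// is currently hard-coded to false above.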

inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {

  // disable NHWC for float64 input.
  if (input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  // disable NHWC for MkldnnCPU tensor.
  if (input.is_mkldnn() || weight.is_mkldnn()) {
    return false;
  }

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  bool can_use_mkldnn_channels_last_2d =
      (input_memory_format == at::MemoryFormat::ChannelsLast) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast);

  bool can_use_mkldnn_channels_last_3d =
      (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast3d);

  return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
}

inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
      (input_memory_format == at::MemoryFormat::ChannelsLast) || (
       weight_memory_format == at::MemoryFormat::ChannelsLast));

  return can_use_thnn_channels_last_2d;
}

inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {

  // check layout only for xpu tensor.
  if (!input.is_xpu() || !weight.is_xpu()) {
    return false;
  }

  // disable NHWC for float64 input.
  if (input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  bool can_use_xpu_channels_last_2d =
      (input_memory_format == at::MemoryFormat::ChannelsLast) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast);

  bool can_use_xpu_channels_last_3d =
      (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast3d);

  return can_use_xpu_channels_last_2d || can_use_xpu_channels_last_3d;
}

} // namespace at::native