#ifdef USE_XNNPACK

#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/Engine.h>
#include <ATen/native/utils/Factory.h>

namespace at::native::xnnpack {

bool use_channel_shuffle(
    const Tensor& input,
    const int64_t groups) {
  using namespace internal;

  // Here are the conditions required for this code path to be taken:
  // * Input must be a 4D CPU float tensor with no gradients, a non-negative
  //   batch size, and positive channel, height, and width dimensions.
  // * The number of groups must be larger than 1 and the number of channels
  //   must be divisible by the number of groups.
  return xnnpack::available() &&
      // Input
      (4 == input.dim()) &&
      (input.device().is_cpu()) &&
      (kFloat == input.scalar_type()) &&
      (input.size(Layout::Activation4D::batch) >= 0) &&
      (input.size(Layout::Activation4D::channels) > 0) &&
      (input.size(Layout::Activation4D::height) > 0) &&
      (input.size(Layout::Activation4D::width) > 0) &&
      !input.requires_grad() &&
      // Groups
      groups > 1 &&
      (0 == input.size(Layout::Activation4D::channels) % groups) &&
      true;
}
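
// Illustrative usage sketch (not part of this file): a caller is expected to
// gate the XNNPACK path on use_channel_shuffle() before invoking
// channel_shuffle() below. The dispatcher name and the fallback
// (channel_shuffle_fallback) are hypothetical placeholders.
//
//   Tensor channel_shuffle_dispatch(const Tensor& self, const int64_t groups) {
//     if (at::native::xnnpack::use_channel_shuffle(self, groups)) {
//       return at::native::xnnpack::channel_shuffle(self, groups);
//     }
//     return channel_shuffle_fallback(self, groups);  // hypothetical fallback
//   }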

Tensor channel_shuffle(
    const Tensor& input,
    const int64_t groups) {
  using namespace internal;

  // A call to channel_shuffle must have been gated by a call to use_channel_shuffle,
  // so the parameters are guaranteed to be valid at this point.

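  // Get an NHWC (channels-last) contiguous copy of the input, allocated with
  // tail padding for XNNPACK, unless the input is already in a suitable layout.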
  const Tensor input_padded_contig_nhwc =
      mobile::allocate_padded_contiguous_if_needed(
          input,
          MemoryFormat::ChannelsLast);

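  // Allocate the output with the same NHWC shape, again with tail padding.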
  Tensor output_padded_contig_nhwc = mobile::empty_with_tail_padding(
      {
        input_padded_contig_nhwc.size(Layout::Activation4D::batch),
        input_padded_contig_nhwc.size(Layout::Activation4D::channels),
        input_padded_contig_nhwc.size(Layout::Activation4D::height),
        input_padded_contig_nhwc.size(Layout::Activation4D::width),
      },
      input_padded_contig_nhwc.options().dtype(),
      MemoryFormat::ChannelsLast,
      input_padded_contig_nhwc.opt_names());

  int64_t channels_per_group =
      input_padded_contig_nhwc.size(Layout::Activation4D::channels) / groups;

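  // Create the XNNPACK channel-shuffle operator. Per pixel, it reshapes the
  // channel dimension into (groups, channels_per_group), transposes it to
  // (channels_per_group, groups), and flattens it back to C channels.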
  xnn_operator_t channel_shuffle_op{};

  const xnn_status create_status = xnn_create_channel_shuffle_nc_x32(
      groups,              // number of groups
      channels_per_group,  // number of channels per group
      input_padded_contig_nhwc.size(Layout::Activation4D::channels),   // input_pixel_stride - NHWC Contiguous
      output_padded_contig_nhwc.size(Layout::Activation4D::channels),  // output_pixel_stride - NHWC Contiguous
      0u,                  // flags
      &channel_shuffle_op);  // operator

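  // Operator is a scoped (RAII) wrapper that deletes the underlying
  // xnn_operator_t when it goes out of scope, so the raw handle does not leak
  // on any of the error paths below.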
  Operator channel_shuffle_scoped_op(channel_shuffle_op);

  TORCH_CHECK(
      xnn_status_success == create_status,
      "xnn_create_channel_shuffle_nc_x32 failed!");

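  // The NC ("channels-last") operator shuffles one row of C channels per
  // batch element, so the effective batch size covers every spatial position
  // of every image: N * H * W.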
  int64_t batch_size = input_padded_contig_nhwc.size(Layout::Activation4D::batch) *
      input_padded_contig_nhwc.size(Layout::Activation4D::height) *
      input_padded_contig_nhwc.size(Layout::Activation4D::width);

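  // Size the operator for the effective batch computed above.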
  const xnn_status reshape_status = xnn_reshape_channel_shuffle_nc_x32(
      channel_shuffle_op,       // operator
      batch_size,               // batch_size
      caffe2::pthreadpool_());  // threadpool

  TORCH_CHECK(
      xnn_status_success == reshape_status,
      "xnn_reshape_channel_shuffle_nc_x32 failed!");

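  // Bind the input and output buffers to the operator.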
  const xnn_status setup_status = xnn_setup_channel_shuffle_nc_x32(
      channel_shuffle_op,                            // operator
      input_padded_contig_nhwc.data_ptr<float>(),    // input
      output_padded_contig_nhwc.data_ptr<float>());  // output

  TORCH_CHECK(
      xnn_status_success == setup_status,
      "xnn_setup_channel_shuffle_nc_x32 failed!");

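  // Run the shuffle on the shared pthreadpool.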
  const xnn_status run_status = xnn_run_operator(
      channel_shuffle_op,       // operator
      caffe2::pthreadpool_());  // threadpool

  TORCH_INTERNAL_ASSERT(
      xnn_status_success == run_status,
      "xnn_run_operator failed!");

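  // Convert back to the memory format suggested by the original input
  // (e.g. NCHW-contiguous), copying only if the layout differs.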
  return output_padded_contig_nhwc.contiguous(input.suggest_memory_format());
}

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */