xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/xnnpack/ChannelShuffle.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef USE_XNNPACK
2 
3 #include <ATen/native/xnnpack/Common.h>
4 #include <ATen/native/xnnpack/Engine.h>
5 #include <ATen/native/utils/Factory.h>
6 
7 namespace at::native::xnnpack {
8 
use_channel_shuffle(const Tensor & input,const int64_t groups)9 bool use_channel_shuffle(
10     const Tensor& input,
11     const int64_t groups) {
12   using namespace internal;
13 
14   // Here are the list of conditions required for this code path to be taken:
15   // * Input must be 4D CPU float tensor with no gradients and
16   //   and all dimensions must be positive.
17   // * The number of groups must be larger than 1 and
18   //   the number of channels must be divisible by the number of groups.
19   return xnnpack::available() &&
20       // Input
21       (4 == input.dim()) &&
22       (input.device().is_cpu()) &&
23       (kFloat == input.scalar_type()) &&
24       (input.size(Layout::Activation4D::batch) >= 0) &&
25       (input.size(Layout::Activation4D::channels) > 0) &&
26       (input.size(Layout::Activation4D::height) > 0) &&
27       (input.size(Layout::Activation4D::width) > 0) &&
28       !input.requires_grad() &&
29       // Groups
30       groups > 1 &&
31       (0 == input.size(Layout::Activation4D::channels) % groups) &&
32       true;
33 }
34 
channel_shuffle(const Tensor & input,const int64_t groups)35 Tensor channel_shuffle(
36     const Tensor& input,
37     const int64_t groups) {
38   using namespace internal;
39 
40   // A call to channel_shuffle must have been gated by a call to use_channel_shuffle,
41   // so the parameters are guaranteed to be valid at this point.
42 
43   const Tensor input_padded_contig_nhwc =
44       mobile::allocate_padded_contiguous_if_needed(
45           input,
46           MemoryFormat::ChannelsLast);
47 
48   Tensor output_padded_contig_nhwc = mobile::empty_with_tail_padding(
49       {
50         input_padded_contig_nhwc.size(Layout::Activation4D::batch),
51         input_padded_contig_nhwc.size(Layout::Activation4D::channels),
52         input_padded_contig_nhwc.size(Layout::Activation4D::height),
53         input_padded_contig_nhwc.size(Layout::Activation4D::width),
54       },
55       input_padded_contig_nhwc.options().dtype(),
56       MemoryFormat::ChannelsLast,
57       input_padded_contig_nhwc.opt_names());
58 
59   int64_t channels_per_group =
60       input_padded_contig_nhwc.size(Layout::Activation4D::channels) / groups;
61 
62   xnn_operator_t channel_shuffle_op{};
63 
64   const xnn_status create_status = xnn_create_channel_shuffle_nc_x32(
65       groups,                                                         // number of groups
66       channels_per_group,                                             // number of channels per group
67       input_padded_contig_nhwc.size(Layout::Activation4D::channels),  // input_pixel_stride - NHWC Contiguous
68       output_padded_contig_nhwc.size(Layout::Activation4D::channels), // output_pixel_stride - NHWC Contiguous
69       0u,                                                             // flags
70       &channel_shuffle_op);                                           // operator
71 
72   Operator channel_shuffle_scoped_op(channel_shuffle_op);
73 
74   TORCH_CHECK(
75       xnn_status_success == create_status,
76       "xnn_create_channel_shuffle_nc_x32 failed!");
77 
78   int64_t batch_size = input_padded_contig_nhwc.size(Layout::Activation4D::batch) *
79                        input_padded_contig_nhwc.size(Layout::Activation4D::height) *
80                        input_padded_contig_nhwc.size(Layout::Activation4D::width);
81 
82   const xnn_status reshape_status = xnn_reshape_channel_shuffle_nc_x32(
83       channel_shuffle_op,                                           // operator
84       batch_size,                                                   // batch_size
85       caffe2::pthreadpool_());                                      // threadpool
86 
87   TORCH_CHECK(
88       xnn_status_success == reshape_status,
89       "xnn_reshape_channel_shuffle_nc_x32 failed!");
90 
91   const xnn_status setup_status = xnn_setup_channel_shuffle_nc_x32(
92       channel_shuffle_op,                                           // operator
93       input_padded_contig_nhwc.data_ptr<float>(),                   // input
94       output_padded_contig_nhwc.data_ptr<float>());                 // output
95 
96   TORCH_CHECK(
97       xnn_status_success == setup_status,
98       "xnn_setup_channel_shuffle_nc_x32 failed!");
99 
100   const xnn_status run_status = xnn_run_operator(
101       channel_shuffle_op,       // operator
102       caffe2::pthreadpool_());  // threadpool
103 
104   TORCH_INTERNAL_ASSERT(
105       xnn_status_success == run_status,
106       "xnn_run_operator failed!");
107 
108   return output_padded_contig_nhwc.contiguous(input.suggest_memory_format());
109 }
110 
111 } // namespace at::native::xnnpack
112 
113 #endif /* USE_XNNPACK */
114