// aten/src/ATen/native/quantized/cpu/ChannelShuffle.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized_native.h>
#include <ATen/ops/channel_shuffle_native.h>
#endif

namespace at {
namespace native {

#ifdef USE_PYTORCH_QNNPACK
namespace {
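// Channel shuffle with g groups reorders the C channels by viewing them as a
// (g, C/g) matrix and transposing it: output channel o is taken from input
// channel (o % g) * (C / g) + o / g. For example, C = 6, g = 2 maps channels
// [0 1 2 3 4 5] -> [0 3 1 4 2 5].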
Tensor quantized_channel_shuffle_impl(
    const Tensor& self,
    int64_t groups) {

  TORCH_CHECK(
      groups > 0,
      "Number of groups to divide channels in must be positive.",
      " Value of groups:", groups);
  TORCH_CHECK(
      self.dim() == 4,
      "channel_shuffle expects 4D input, but got input with sizes ",
      self.sizes());
  TORCH_CHECK(
      self.scalar_type() == kQUInt8,
      "Quantized channel shuffle works only on ",
      toString(c10::kQUInt8),
      " but got ", self.scalar_type());
  const Tensor self_nhwc = self.contiguous(MemoryFormat::ChannelsLast);
  Tensor qy = at::native::empty_affine_quantized(
      self_nhwc.sizes(),
      kQUInt8,
      std::nullopt /* layout */,
      kCPU,
      std::nullopt /* pin_memory */,
      self_nhwc.q_scale(),
      self_nhwc.q_zero_point(),
      MemoryFormat::ChannelsLast);

  // Degenerate case of just copying.
  if (groups == 1) {
    qy.copy_(self_nhwc);
    return qy.contiguous(self.suggest_memory_format());
  }

  int64_t channels = self.size(1);
  TORCH_CHECK(channels > 0,
             "Number of channels must be positive, got:", channels);
  TORCH_CHECK((channels % groups) == 0,
             "Number of channels must be divisible by groups. Got ",
             channels, " channels and ", groups, " groups.");

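  // One-time, lazy initialization of the QNNPACK library; subsequent calls
  // are no-ops.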
  initQNNPACK();

  pytorch_qnnp_operator_t qnnpack_operator{nullptr};

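  // QNNPACK operators follow a create/setup/run lifecycle: create fixes the
  // shape-independent parameters (groups, channels per group), setup binds
  // the input/output pointers, and run executes on a thread pool.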
  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_channel_shuffle_nc_x8(
      groups /* groups */,
      channels / groups /* group channels */,
      0 /* flags */,
      &qnnpack_operator);
  TORCH_INTERNAL_ASSERT(
      createStatus == pytorch_qnnp_status_success,
      "failed to create QNNPACK ChannelShuffle operator");

  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(qnnpack_operator);

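  // The NC x8 channel-shuffle kernel views the NHWC tensor as a 2D
  // (batch, channels) matrix in which every spatial position is one row,
  // hence batch size = numel / channels = N * H * W.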
  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_channel_shuffle_nc_x8(
      qnnpack_uniq_ptr.get(),
      self_nhwc.numel() / channels /* batch size */,
      (uint8_t*)self_nhwc.data_ptr<c10::quint8>() /* self data */,
      channels /* self stride */,
      (uint8_t*)qy.data_ptr<c10::quint8>() /* qy data */,
      channels /* qy stride */);
  TORCH_INTERNAL_ASSERT(
      setupStatus == pytorch_qnnp_status_success,
      "failed to setup QNNPACK ChannelShuffle operator");

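  // Execute on the shared Caffe2/PyTorch mobile thread pool.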
  pthreadpool_t threadpool = caffe2::pthreadpool_();
  const pytorch_qnnp_status runStatus =
      pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK ChannelShuffle operator");

  return qy.contiguous(self.suggest_memory_format());
}
} // namespace
#endif // USE_PYTORCH_QNNPACK

// at::native functions for the native_functions.yaml
Tensor channel_shuffle_quantized_cpu(
    const Tensor& self,
    int64_t groups) {
#ifdef USE_PYTORCH_QNNPACK
  return quantized_channel_shuffle_impl(self, groups);
#endif
  // If QNNPACK is not available, fall back to the
  // generic non-quantized path.
  return at::native::channel_shuffle(self, groups);
}

// Keep the registry in the anonymous namespace.
namespace {
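// Legacy functor-style kernel (c10::OperatorKernel) retained for operator
// registration; it simply forwards to the native implementation above.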
class QChannelShuffle final : public c10::OperatorKernel {
 public:
  Tensor operator()(Tensor qx, int64_t groups) {
    return channel_shuffle_quantized_cpu(qx, groups);
  }
};

} // namespace

} // namespace native
} // namespace at