#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized_native.h>
#include <ATen/ops/channel_shuffle_native.h>
#endif

namespace at {
namespace native {

#ifdef USE_PYTORCH_QNNPACK
namespace {
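// QNNPACK-backed channel shuffle for quint8 tensors: splits the C channels
// of an (N, C, H, W) input into `groups` groups and interleaves them,
// working on a channels-last (NHWC) copy of the data.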
Tensor quantized_channel_shuffle_impl(
    const Tensor& self,
    int64_t groups) {

  TORCH_CHECK(
      groups > 0,
      "Number of groups to divide channels in must be positive.",
      " Value of groups:", groups);
  TORCH_CHECK(
      self.dim() == 4,
      "channel_shuffle expects 4D input, but got input with sizes ",
      self.sizes());
  TORCH_CHECK(
      self.scalar_type() == kQUInt8,
      "Quantized channel shuffle works only on ",
      toString(c10::kQUInt8),
      " but got ", self.scalar_type());
  const Tensor self_nhwc = self.contiguous(MemoryFormat::ChannelsLast);
  Tensor qy = at::native::empty_affine_quantized(
      self_nhwc.sizes(),
      kQUInt8,
      std::nullopt /* layout */,
      kCPU,
      std::nullopt /* pin_memory */,
      self_nhwc.q_scale(),
      self_nhwc.q_zero_point(),
      MemoryFormat::ChannelsLast);

  // Degenerate case of just copying.
  if (groups == 1) {
    qy.copy_(self_nhwc);
    return qy.contiguous(self.suggest_memory_format());
  }

  int64_t channels = self.size(1);
  TORCH_CHECK(channels > 0,
      "Number of channels must be positive, got:", channels);
  TORCH_CHECK((channels % groups) == 0,
      "Number of channels must be divisible by groups. Got ",
      channels, " channels and ", groups, " groups.");

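  // Make sure the QNNPACK runtime has been initialized before any
  // operators are created.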
  initQNNPACK();

  pytorch_qnnp_operator_t qnnpack_operator{nullptr};

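  // Create a channel-shuffle operator configured for `groups` groups of
  // `channels / groups` channels each.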
  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_channel_shuffle_nc_x8(
      groups /* groups */,
      channels / groups /* group channels */,
      0 /* flags */,
      &qnnpack_operator);
  TORCH_INTERNAL_ASSERT(
      createStatus == pytorch_qnnp_status_success,
      "failed to create QNNPACK ChannelShuffle operator");

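  // Hand ownership of the raw operator to a unique_ptr so it is released
  // even if one of the checks below throws.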
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(qnnpack_operator);

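  // Bind the input and output buffers. Each row of `channels` contiguous
  // NHWC elements is shuffled independently, so the batch size is
  // numel / channels and both strides equal `channels`.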
  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_channel_shuffle_nc_x8(
      qnnpack_uniq_ptr.get(),
      self_nhwc.numel() / channels /* batch size */,
      (uint8_t*)self_nhwc.data_ptr<c10::quint8>() /* self data */,
      channels /* self stride */,
      (uint8_t*)qy.data_ptr<c10::quint8>() /* qy data */,
      channels /* qy stride */);
  TORCH_INTERNAL_ASSERT(
      setupStatus == pytorch_qnnp_status_success,
      "failed to setup QNNPACK ChannelShuffle operator");

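  // Run the operator on the shared caffe2 thread pool and return the
  // result in the caller's preferred memory format.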
  pthreadpool_t threadpool = caffe2::pthreadpool_();
  const pytorch_qnnp_status runStatus =
      pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK ChannelShuffle operator");

  return qy.contiguous(self.suggest_memory_format());
}
} // namespace
#endif

// at::native functions for the native_functions.yaml
Tensor channel_shuffle_quantized_cpu(
    const Tensor& self,
    int64_t groups) {
#ifdef USE_PYTORCH_QNNPACK
  return quantized_channel_shuffle_impl(self, groups);
#endif
  // If QNNPACK is not available then fall back to the
  // non-quantized path.
  return at::native::channel_shuffle(self, groups);
}
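// From Python this kernel is typically reached through torch.channel_shuffle
// on a quantized tensor, e.g. (illustrative only):
//   qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
//   qy = torch.channel_shuffle(qx, 4)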

// Keep the registry in the anonymous namespace.
namespace {
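// Functor wrapper (c10::OperatorKernel) that forwards to
// channel_shuffle_quantized_cpu.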
class QChannelShuffle final : public c10::OperatorKernel {
 public:
  Tensor operator()(Tensor qx, int64_t groups) {
    return channel_shuffle_quantized_cpu(qx, groups);
  }
};

} // namespace

} // namespace native
} // namespace at