// aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <torch/library.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <c10/util/irange.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/sigmoid_native.h>
#endif

#include <algorithm>
#include <limits>  // std::numeric_limits, used for the quint8 output range
#include <utility>

namespace at {
namespace native {

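// qsigmoid_stub is a DispatchStub: this file only declares the dispatch
// entry, while the actual CPU kernel is registered separately via
// REGISTER_DISPATCH (the vectorized kernel under quantized/cpu/kernels/).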
DEFINE_DISPATCH(qsigmoid_stub);

#ifdef USE_PYTORCH_QNNPACK
static Tensor qnnpack_sigmoid(
    Tensor input, double output_scale, int64_t output_zero_point) {
  TORCH_CHECK(input.ndimension() > 0, "qnnpack_sigmoid(): Got empty input tensor");
  TORCH_CHECK(input.scalar_type() == c10::kQUInt8,
              "qnnpack_sigmoid(): Expected input data type ",
              toString(c10::kQUInt8),
              " but got ",
              toString(input.scalar_type()));

  Tensor qy;
  initQNNPACK();

  Tensor input_contig = input.contiguous(input.suggest_memory_format());
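  // QNNPACK's *_nc_q8 operators view a tensor as a 2-D [batch, channels]
  // matrix: dim 0 is the batch and all remaining dims are flattened into the
  // channel count (e.g. a {2, 3, 4} input becomes batch = 2, channels = 12).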
  size_t num_elems = 1;
  for (const auto i : c10::irange(1, input_contig.ndimension())) {
    num_elems *= input_contig.size(i);
  }

  const auto zero_point = input_contig.q_zero_point();
  const auto scale = input_contig.q_scale();

  pytorch_qnnp_operator_t sigmoid_op{nullptr};
  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_sigmoid_nc_q8(
    num_elems /* channels */,
    zero_point /* input zero point */,
    scale /* input scale */,
    output_zero_point /* output zero point */,
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    output_scale /* output scale */,
    std::numeric_limits<uint8_t>::min() /* output min */,
    std::numeric_limits<uint8_t>::max() /* output max */,
    0 /* flags */,
    &sigmoid_op);

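  // Hand the raw operator to a unique_ptr right away so that
  // pytorch_qnnp_delete_operator runs on every exit path, including the
  // TORCH_INTERNAL_ASSERT failures below.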
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(sigmoid_op);

  TORCH_INTERNAL_ASSERT(createStatus == pytorch_qnnp_status_success,
                        "failed to create QNNPACK sigmoid operator");
  qy = at::_empty_affine_quantized(
    input_contig.sizes(),
    at::device(kCPU).dtype(input_contig.dtype()),
    output_scale,
    output_zero_point,
    input_contig.suggest_memory_format());

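  // The stride arguments are in elements between consecutive batch rows; both
  // tensors are contiguous here, so the stride is simply the channel count.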
  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_sigmoid_nc_q8(
    sigmoid_op,
    input_contig.size(0) /* batch size */,
    (uint8_t*)input_contig.data_ptr<c10::quint8>() /* input data */,
    num_elems /* input stride */,
    (uint8_t*)qy.data_ptr<c10::quint8>() /* output data */,
    num_elems /* output stride */);
  TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
                        "failed to setup QNNPACK sigmoid operator");

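  // caffe2::pthreadpool_() returns the process-global thread pool shared by
  // the mobile backends (QNNPACK/NNPACK/XNNPACK); QNNPACK uses it to
  // parallelize the run below.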
  pthreadpool_t threadpool = caffe2::pthreadpool_();

  const pytorch_qnnp_status runStatus =
    pytorch_qnnp_run_operator(sigmoid_op, threadpool);

  TORCH_INTERNAL_ASSERT(
    runStatus == pytorch_qnnp_status_success,
    "failed to run QNNPACK sigmoid operator");
  return qy;
}

#endif  // USE_PYTORCH_QNNPACK

// Output quantization parameters are fixed rather than derived from the input.
// The QNNPACK path always outputs scale = 1.0/256, zero_point = 0, dtype = quint8.
// The fallback stub keeps the input dtype; scale is 1.0/2^(bit width) and
// zero_point is 0, except -128 for qint8.
Tensor sigmoid_quantized_cpu(const Tensor& qx) {
#ifdef USE_PYTORCH_QNNPACK
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
      qx.scalar_type() == kQUInt8) {
    constexpr double output_scale = 1.0 / 256.0;
    constexpr int64_t output_zero_point = 0;
    return qnnpack_sigmoid(qx, output_scale, output_zero_point);
  }
#endif  // USE_PYTORCH_QNNPACK
  Tensor qy;
  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() {
    // Naive implementation: dequantize, apply sigmoid, requantize.
    // - Output scale is set to 1.0 / 2^(bit width of the type).
    // - Output zero point is 0, except for qint8 where it is -128 so that
    //   the full [qmin, qmax] range covers sigmoid's output range [0, 1].
    // See https://stackoverflow.com/a/34448562/3606192 for potential
    // optimizations.
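    // Worked example for qint8 (scale = 1/256, zero_point = -128): a stored
    // value q dequantizes to (q - (-128)) / 256, so q = -128 encodes 0.0 and
    // q = 127 encodes 255/256 ≈ 0.996.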
    double output_scale = 0.00390625;  // 1.0 / 2^8
    int64_t output_zero_point = 0;
    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
    if (SCALAR_TYPE == at::kQInt32) {
      output_scale = 2.3283064365386963e-10;  // 1.0 / 2^32
    } else if (SCALAR_TYPE == at::kQInt8) {
      output_zero_point = -128;
    }
    qsigmoid_stub(qx.device().type(), qx, qy, output_scale, output_zero_point);
  });
  return qy;
}

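// Illustrative usage sketch: on CPU, at::sigmoid on a quantized tensor
// dispatches to sigmoid_quantized_cpu above, e.g.:
//
//   Tensor x = at::rand({4});
//   Tensor qx = at::quantize_per_tensor(
//       x, /*scale=*/0.1, /*zero_point=*/0, c10::kQUInt8);
//   Tensor qy = at::sigmoid(qx);  // qy.q_scale() == 1.0/256, zero point 0
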
namespace {

class QSigmoid final {
 public:
  static Tensor run(Tensor qx, double output_scale, int64_t output_zero_point) {
#ifdef USE_PYTORCH_QNNPACK
    if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
        qx.scalar_type() == kQUInt8) {
      return qnnpack_sigmoid(std::move(qx), output_scale, output_zero_point);
    }
#endif  // USE_PYTORCH_QNNPACK
    Tensor qy;
    qsigmoid_stub(qx.device().type(), qx, qy, output_scale, output_zero_point);
    return qy;
  }
};

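// Register quantized::sigmoid, which, unlike the at::sigmoid overload above,
// takes the output scale and zero point explicitly from the caller.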
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::sigmoid"), TORCH_FN(QSigmoid::run));
}
} // namespace

}}  // namespace at::native