#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <torch/library.h>
#include <ATen/native/quantized/AffineQuantizerBase.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <c10/util/irange.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/clamp_native.h>
#include <ATen/ops/hardtanh_native.h>
#endif

#include <algorithm>

namespace at {
namespace native {

DEFINE_DISPATCH(qclamp_stub);
DEFINE_DISPATCH(qclamp_min_stub);
DEFINE_DISPATCH(qclamp_max_stub);

namespace {

#ifdef USE_PYTORCH_QNNPACK
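// QNNPACK fast path for clamp on quint8 tensors: the floating-point bounds
// are quantized into the input's domain and the clamp runs as a u8 kernel.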
Tensor qnnpack_clamp(Tensor input, const Scalar& min, const Scalar& max) {
  TORCH_CHECK(input.ndimension() > 0, "qnnpack_clamp(): Got empty input tensor");

  initQNNPACK();

  Tensor input_contig = input.contiguous(input.suggest_memory_format());
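  // Dim 0 is the batch; fold every remaining dim into a single per-row
  // element count, since the QNNPACK operator works on flat rows.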
  size_t num_elems = 1;
  for (const auto i : c10::irange(1, input_contig.ndimension())) {
    num_elems *= input_contig.size(i);
  }

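  // Quantize the clamp bounds with the input's scale and zero point so the
  // comparison can happen directly on uint8 values.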
  auto min_f = min.to<float>();
  auto max_f = max.to<float>();
  uint8_t min_q =
      at::native::quantize_val<quint8>(input.q_scale(), input.q_zero_point(), min_f).val_;
  uint8_t max_q =
      at::native::quantize_val<quint8>(input.q_scale(), input.q_zero_point(), max_f).val_;

  pytorch_qnnp_operator_t clamp_op{nullptr};
  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_clamp_nc_u8(
    num_elems, // channels
    min_q,
    max_q,
    0, // flags
    &clamp_op);

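  // Hand the raw operator to a unique_ptr right away so it is released even
  // if one of the status checks below throws.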
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(clamp_op);

  TORCH_INTERNAL_ASSERT(createStatus == pytorch_qnnp_status_success,
                        "failed to create QNNPACK Clamp operator");

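  // Clamp preserves quantization parameters, so the output tensor reuses the
  // input's scale and zero point.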
  Tensor qy = at::_empty_affine_quantized(
    input_contig.sizes(),
    input_contig.options(),
    input_contig.q_scale(),
    input_contig.q_zero_point());

  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_clamp_nc_u8(
    clamp_op,
    input_contig.size(0), // batch_size
    (uint8_t*)input_contig.data_ptr<c10::quint8>(), // input_data
    num_elems, // input_stride
    (uint8_t*)qy.data_ptr<c10::quint8>(), // output_data
    num_elems); // output_stride
  TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
                        "failed to setup QNNPACK Clamp operator");

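  // Run the operator on the shared Caffe2 pthreadpool.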
  pthreadpool_t threadpool = caffe2::pthreadpool_();

  const pytorch_qnnp_status runStatus =
    pytorch_qnnp_run_operator(clamp_op, threadpool);

  TORCH_INTERNAL_ASSERT(
    runStatus == pytorch_qnnp_status_success,
    "failed to run QNNPACK Clamp operator");
  return qy;
}

#endif // USE_PYTORCH_QNNPACK

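// Shared implementation for clamp and hardtanh: prefer the QNNPACK kernel
// when both bounds are given on a quint8 tensor, otherwise dispatch to the
// native per-backend stubs.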
Tensor quantized_clamp_impl(
    const Tensor& qx,
    const std::optional<Scalar>& min,
    const std::optional<Scalar>& max) {
  Tensor qy;
  if (min && max) {
#ifdef USE_PYTORCH_QNNPACK
    if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
        qx.scalar_type() == kQUInt8) {
      return qnnpack_clamp(qx, *min, *max);
    }
#endif
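    // Fall back to the native clamp kernel for this device type.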
    qclamp_stub(qx.device().type(), qx, *min, *max, qy);
  } else {
#ifdef USE_PYTORCH_QNNPACK
    if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
      TORCH_CHECK(
          false, "Both min and max should be specified for quantized clamp!");
    }
#endif
    if (max) {
      qclamp_max_stub(qx.device().type(), qx, *max, qy);
    } else if (min) {
      qclamp_min_stub(qx.device().type(), qx, *min, qy);
    } else {
      TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None");
    }
  }
  return qy;
}
} // namespace

// at::native functions for native_functions.yaml
Tensor clamp_quantized_cpu(
    const Tensor& qx,
    const std::optional<Scalar>& min,
    const std::optional<Scalar>& max) {
  Tensor qy;
  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "clamp", [&]() {
    qy = quantized_clamp_impl(qx, min, max);
  });
  return qy;
}

// hardtanh is clamp with default min == -1.0f and default max == 1.0f
Tensor hardtanh_quantized_cpu(
    const Tensor& qx,
    const Scalar& min,
    const Scalar& max) {
  Tensor qy = quantized_clamp_impl(qx, min, max);
  return qy;
}

Tensor& hardtanh_out_quantized_cpu(const Tensor& qx,
    const Scalar& min,
    const Scalar& max,
    Tensor& result) {
  result = quantized_clamp_impl(qx, min, max);
  return result;
}

Tensor& hardtanh_quantized_cpu_(
    Tensor& self,
    const Scalar& min,
    const Scalar& max) {
  Tensor qy = quantized_clamp_impl(self, min, max);
  // This can be optimized in a future PR if it becomes a bottleneck.
  self.copy_(qy);
  return self;
}

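// Register the quantized::clamp overload for the QuantizedCPU dispatch key.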
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::clamp"), TORCH_FN(clamp_quantized_cpu));
}

} // namespace native
} // namespace at