// aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <torch/custom_class.h>
#include <torch/library.h>
#include <c10/util/accumulate.h>

#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
#include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h>

#include <limits>
#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/empty.h>
#endif

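// Dynamically quantized sparse linear operator for the QNNPACK backend.
// The float input is quantized to quint8 on the fly from its observed
// min/max, multiplied against a prepacked block-sparse (BCSR) weight
// matrix, and the result is dequantized back to float.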
namespace ao {
namespace sparse {

int register_linear_params();

#ifdef USE_PYTORCH_QNNPACK
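// ReLU-fused variant: not implemented for the sparse qnnpack path, so this
// specialization always asserts. sparse::qlinear_relu_dynamic still routes
// here via apply_dynamic_relu below.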
template <>
at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl<true>(
    const at::Tensor& input) {
  TORCH_INTERNAL_ASSERT(
      false,
      "Sparse quantized dynamic linear with fused relu is not yet "
      "supported on qnnpack backend.");
  return at::Tensor();
}

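// Non-fused variant. The overall flow: validate shapes, pick input
// quantization params from the observed min/max, quantize the input to
// quint8, lazily create (then reuse) the QNNPACK sparse fully-connected
// operator, and run it to produce a float output.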
template <>
at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl<false>(
    const at::Tensor& input) {
  TORCH_CHECK(
      input.dim() >= 2,
      "quantized_sparse_linear(): Input tensor rank should be >= 2");

  const auto rows_input = c10::multiply_integers(
      input.sizes().begin(), input.sizes().end() - 1);
  const auto cols_input = static_cast<int64_t>(input.size(input.dim() - 1));
  TORCH_CHECK(
      cols_input == input_channels_,
      "quantized_sparse_linear: Input tensor's last and weight tensor's"
      " second dimension must match.");

  // On empty input no output data is generated, so arbitrary qparams
  // suffice; otherwise derive them from the observed input range.
  float x_min = 0;
  float x_max = 0;
  if (input.numel() > 0) {
    x_min = input.min().item<float>();
    x_max = input.max().item<float>();
  }

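  // ChooseQuantizationParams maps the observed range onto [qmin, qmax]
  // (roughly scale = (max - min) / (qmax - qmin)), picking a zero_point so
  // that the real value 0.0f is exactly representable.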
  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/0,
      /*qmax=*/255);

  // Quantize input
  at::Tensor q_input = at::quantize_per_tensor(
      input, q_params.scale, q_params.zero_point, c10::kQUInt8);

  auto q_input_contig = q_input.contiguous();
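  // Create the QNNPACK operator lazily on first use; subsequent calls reuse
  // it and only refresh the input-dependent quantization params below.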
  if (sparse_linear_op_ == nullptr) {
    // Compute the requantization scales here: the vector holding them is
    // owned by this module, and a raw pointer into it is passed to the
    // qnnpack backend.
    generate_requantization_scales(
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        w_scales_, q_input_contig.q_scale(), 1.f, requantization_scales_);
    input_scale_ = q_input_contig.q_scale();
    pytorch_qnnp_operator_t sparse_linear_op{nullptr};
    pytorch_qnnp_status status =
        pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8(
            input_channels_,
            output_channels_,
            q_input_contig.q_zero_point(),
            w_zero_points_.data(),
            bcsr_matrix_->col_indices_data_ptr(),
            bcsr_matrix_->row_values_data_ptr(),
            bcsr_matrix_->values.data(),
            bcsr_matrix_->row_block_size, /* out_features_block_size */
            bcsr_matrix_->col_block_size, /* in_features_block_size */
            bcsr_matrix_->indices_dtype,
            0, /* output zero point: not used */
            std::numeric_limits<uint8_t>::min(),
            std::numeric_limits<uint8_t>::max(),
            0, /* flags */
            requantization_scales_.data(),
            true, /* use prepacking kernel */
            &sparse_linear_op);
    TORCH_CHECK(
        status == pytorch_qnnp_status_success,
        "Failed to create sparse linear operator on"
        " qnnpack backend.");
    sparse_linear_op_ =
        std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>(
            sparse_linear_op);
  }

  // The input (and hence its scale) can differ on each invocation, which
  // requires recalculating the requantization scales.
  if (input_scale_ != q_input_contig.q_scale()) {
    generate_requantization_scales(
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        w_scales_, q_input_contig.q_scale(), 1.f, requantization_scales_);
  }
  // Update input-related quantization params in the operator.
  sparse_linear_op_->dynamic_conv_quantization_params.input_zero_point =
      q_input_contig.q_zero_point();
  sparse_linear_op_->dynamic_conv_quantization_params.multipliers =
      requantization_scales_.data();

  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = output_channels_;

  auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));

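  // Bind this call's buffers to the operator. The leading dimensions of the
  // input are flattened into the batch size (rows_input).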
  pytorch_qnnp_status status =
      pytorch_qnnp_setup_fully_connected_sparse_dq_nc_q8(
          sparse_linear_op_.get(),
          rows_input, /* batch size */
          reinterpret_cast<uint8_t*>(q_input_contig.data_ptr<c10::quint8>()),
          cols_input, /* num input channels */
          bias_.data_ptr<float>(),
          output.data_ptr<float>(),
          output_channels_);
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "Failed to setup sparse linear operator on"
      " qnnpack backend.");

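  // Execute the operator on caffe2's shared pthreadpool.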
  status = pytorch_qnnp_run_operator(
      sparse_linear_op_.get(), caffe2::pthreadpool_());
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "Failed to run sparse linear operator on"
      " qnnpack backend.");

  return output;
}

at::Tensor PackedLinearWeightQnnp::apply_dynamic(
    const at::Tensor& input) {
  return apply_dynamic_impl<false>(input);
}

at::Tensor PackedLinearWeightQnnp::apply_dynamic_relu(
    const at::Tensor& input) {
  return apply_dynamic_impl<true>(input);
}

#endif // USE_PYTORCH_QNNPACK

namespace {

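// Dispatch wrapper: selects the backend implementation based on the active
// quantized engine. Only QNNPACK currently provides sparse dynamic linear.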
template <bool ReluFused>
class QLinearDynamicInt8 final {
 public:
  static at::Tensor run(
      const at::Tensor& input,
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
    auto& ctx = at::globalContext();
#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      if (ReluFused) {
        return packed_weight->apply_dynamic_relu(input);
      } else {
        return packed_weight->apply_dynamic(input);
      }
    }
#endif
    TORCH_CHECK(
        false,
        "Didn't find engine for operation ao::sparse::qlinear_dynamic ",
        toString(ctx.qEngine()));
  }
};

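// Register the CPU kernels for the sparse namespace. A hypothetical
// Python-side call, assuming the weight was already prepacked for this
// backend:
//   out = torch.ops.sparse.qlinear_dynamic(x, packed_weight)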
TORCH_LIBRARY_IMPL(sparse, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear_dynamic"),
      TORCH_FN(QLinearDynamicInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear_relu_dynamic"),
      TORCH_FN(QLinearDynamicInt8<true>::run));
}

} // namespace
}} // namespace ao::sparse