#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <torch/custom_class.h>
#include <torch/library.h>
#include <c10/util/accumulate.h>

#include <limits>

#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
#include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/empty.h>
#endif

namespace ao {
namespace sparse {

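// Forward declaration; the definition registers the packed linear params
// custom class used by the sparse quantized ops.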
int register_linear_params();

#ifdef USE_PYTORCH_QNNPACK
template <>
at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl<true>(
    const at::Tensor& input) {
  TORCH_INTERNAL_ASSERT(
      false,
      "Sparse quantized dynamic linear with fused relu is not yet "
      "supported on qnnpack backend.");
  return at::Tensor();
}

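// The non-fused specialization below does the real work: dynamically
// quantize the fp32 input, run the QNNPACK sparse q8 GEMM, and write a
// dequantized fp32 output.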
template <>
at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl<false>(
    const at::Tensor& input) {
  TORCH_CHECK(
      input.dim() >= 2,
      "quantized_sparse_linear(): Input tensor rank should be >= 2");

  const auto rows_input = c10::multiply_integers(input.sizes().begin(), input.sizes().end() - 1);
  const auto cols_input = static_cast<int64_t>(input.size(input.dim() - 1));
  TORCH_CHECK(
      cols_input == input_channels_,
      "quantized_sparse_linear: Input tensor's last and weight tensor's"
      " second dimension must match.");

  // On empty input, no output data will be generated,
  // so use arbitrary qparams.
  float x_min = 0;
  float x_max = 0;
  // Otherwise, take the observed min/max of the input.
  if (input.numel() > 0) {
    x_min = input.min().item<float>();
    x_max = input.max().item<float>();
  }

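  // Dynamic quantization: choose per-tensor affine qparams that map the
  // observed [x_min, x_max] onto the full quint8 range [0, 255].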
  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/0,
      /*qmax=*/255);

  // Quantize input
  at::Tensor q_input = at::quantize_per_tensor(
      input, q_params.scale, q_params.zero_point, c10::kQUInt8);

  auto q_input_contig = q_input.contiguous();
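  // Create the QNNPACK sparse fully-connected operator lazily on first call
  // and cache it; subsequent calls reuse it and only refresh the
  // input-dependent quantization parameters below.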
  if (sparse_linear_op_ == nullptr) {
    // Compute the requantization scales here, since the vector holding them
    // is owned by this module; the pointer is then passed to qnnpack backend.
    generate_requantization_scales(
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        w_scales_, q_input_contig.q_scale(), 1.f, requantization_scales_);
    input_scale_ = q_input_contig.q_scale();
    pytorch_qnnp_operator_t sparse_linear_op{nullptr};
    pytorch_qnnp_status status =
        pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8(
            input_channels_,
            output_channels_,
            q_input_contig.q_zero_point(),
            w_zero_points_.data(),
            bcsr_matrix_->col_indices_data_ptr(),
            bcsr_matrix_->row_values_data_ptr(),
            bcsr_matrix_->values.data(),
            bcsr_matrix_->row_block_size, /* out_features_block_size */
            bcsr_matrix_->col_block_size, /* in_features_block_size */
            bcsr_matrix_->indices_dtype,
            0, /* output zero point: not used */
            std::numeric_limits<uint8_t>::min(),
            std::numeric_limits<uint8_t>::max(),
            0, /* flags */
            requantization_scales_.data(),
            true, /* use prepacking kernel */
            &sparse_linear_op);
    TORCH_CHECK(
        status == pytorch_qnnp_status_success,
        "Failed to create sparse linear operator on"
        " qnnpack backend.");
    sparse_linear_op_ =
        std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>(
            sparse_linear_op);
  }

  // The input on a later call may differ, and hence have a different input
  // scale; in that case the requantization scales must be recalculated.
  if (input_scale_ != q_input_contig.q_scale()) {
    generate_requantization_scales(
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        w_scales_, q_input_contig.q_scale(), 1.f, requantization_scales_);
  }
  // Update input related quantization params in the operator.
  sparse_linear_op_->dynamic_conv_quantization_params.input_zero_point =
      q_input_contig.q_zero_point();
  sparse_linear_op_->dynamic_conv_quantization_params.multipliers =
      requantization_scales_.data();

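  // Output keeps the input shape except that the last dimension becomes the
  // number of output channels; dynamic linear writes fp32 output.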
  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = output_channels_;

  auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));

  pytorch_qnnp_status status =
      pytorch_qnnp_setup_fully_connected_sparse_dq_nc_q8(
          sparse_linear_op_.get(),
          rows_input, /* batch size */
          reinterpret_cast<uint8_t*>(q_input_contig.data_ptr<c10::quint8>()),
          cols_input, /* num input channels */
          bias_.data_ptr<float>(),
          output.data_ptr<float>(),
          output_channels_);
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "Failed to setup sparse linear operator on"
      " qnnpack backend.");

  status = pytorch_qnnp_run_operator(
      sparse_linear_op_.get(), caffe2::pthreadpool_());
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "Failed to run sparse linear operator on"
      " qnnpack backend.");

  return output;
}

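// Public entry points; relu fusion is selected via the template parameter
// of apply_dynamic_impl.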
at::Tensor PackedLinearWeightQnnp::apply_dynamic(
    const at::Tensor& input) {
  return apply_dynamic_impl<false>(input);
}

at::Tensor PackedLinearWeightQnnp::apply_dynamic_relu(
    const at::Tensor& input) {
  return apply_dynamic_impl<true>(input);
}

#endif // USE_PYTORCH_QNNPACK

namespace {

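// Thin dispatch wrapper that routes the sparse qlinear_dynamic ops to the
// packed weight's dynamic apply methods, picking the relu-fused variant at
// compile time.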
template <bool ReluFused>
class QLinearDynamicInt8 final {
 public:
  static at::Tensor run(
      const at::Tensor& input,
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
    auto& ctx = at::globalContext();
#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      if (ReluFused) {
        return packed_weight->apply_dynamic_relu(input);
      } else {
        return packed_weight->apply_dynamic(input);
      }
    }
#endif
    TORCH_CHECK(
        false,
        "Didn't find engine for operation ao::sparse::qlinear_dynamic ",
        toString(ctx.qEngine()));
  }
};

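// Register the CPU kernels for the dynamic sparse linear ops.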
TORCH_LIBRARY_IMPL(sparse, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear_dynamic"),
      TORCH_FN(QLinearDynamicInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear_relu_dynamic"),
      TORCH_FN(QLinearDynamicInt8<true>::run));
}

} // namespace
}} // namespace ao::sparse