#pragma once

#include <ATen/Tensor.h>
#include <c10/core/QScheme.h>

#ifdef USE_PYTORCH_QNNPACK
// TODO: Refactor QnnpackUtils.h so as to separate the code
// needed for quantized ops from the generic qnnpack-specific
// quantization utilities.
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <pack_block_sparse.h>
#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>

namespace ao {
namespace sparse {

struct TORCH_API PackedLinearWeightQnnp
    : public LinearPackedParamsBase {
  PackedLinearWeightQnnp(
      const at::Tensor& weight,
      const std::optional<at::Tensor>& bias,
      const int64_t out_features_block_size /* block sparsity size across output_features */,
      const int64_t in_features_block_size /* block sparsity size across input_features */);
  explicit PackedLinearWeightQnnp(const BCSRSerializationType& serialized);
  std::optional<at::Tensor> orig_bias_;
  // A separate copy of the bias exists so that we can fill in zeros when the
  // optional bias is absent. This is to comply with the qnnpack operator,
  // which expects the bias to be present.
  // When the bias is present, bias_ is just a reference to orig_bias_.
  at::Tensor bias_;
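  // A minimal sketch of that zero-filling behavior (illustrative only; the
  // real logic lives in the .cpp constructor):
  //   bias_ = orig_bias_.has_value()
  //       ? *orig_bias_
  //       : at::zeros({output_channels_}, weight.options().dtype(at::kFloat));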
  c10::QScheme q_scheme_;
  double input_scale_;
  std::unique_ptr<qnnpack::BCSRMatrix> bcsr_matrix_;
  at::Tensor w_scales_;
  std::vector<uint8_t> w_zero_points_;
  std::vector<float> requantization_scales_;
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      sparse_linear_op_{nullptr};
  int64_t output_channels_;
  int64_t input_channels_;
  // Deserialized Tensors are stored to maintain the lifetime of the
  // underlying BCSR data.
  // These are left empty if PackedLinearWeightQnnp is created via prepacking
  // rather than deserializing.
  at::Tensor deserialized_bcsr_row_block_indices_;
  at::Tensor deserialized_bcsr_col_block_indices_;
  at::Tensor deserialized_bcsr_weight_values_;
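  // Lifetime note (an inference from the comment above, not a spec): the
  // deserializing constructor hands pointers into these tensors' storage to
  // qnnpack, so the tensors are retained here to keep that storage alive for
  // as long as bcsr_matrix_ and sparse_linear_op_ may reference it.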
44 
applyPackedLinearWeightQnnp45   at::Tensor apply(
46       const at::Tensor& input,
47       double output_scale,
48       int64_t output_zero_point) override {
49     TORCH_CHECK(
50         false, "Static quantized sparse linear unimplemented on QNNPACK");
51   }
  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override {
    TORCH_CHECK(
        false, "Static quantized sparse linear unimplemented on QNNPACK");
  }

  at::Tensor apply_dynamic(const at::Tensor& input) override;
  at::Tensor apply_dynamic_relu(const at::Tensor& input) override;

  LinearPackedSerializationType unpack() override;

  BCSRSerializationType serialize() override;

  static c10::intrusive_ptr<LinearPackedParamsBase> deserialize(
      const BCSRSerializationType& serialized);

  std::optional<at::Tensor> bias() override {
    return orig_bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      const at::Tensor& weight,
      const std::optional<at::Tensor>& bias,
      const int64_t out_features_block_size,
      const int64_t in_features_block_size);

 private:
  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(const at::Tensor& input);
};
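
// Example usage (a minimal sketch, not part of this header; the weight must
// already be a quantized tensor, and the scale, zero point, and block sizes
// below are purely illustrative):
//
//   at::Tensor qweight = at::quantize_per_tensor(
//       weight_fp32, /*scale=*/0.5, /*zero_point=*/0, at::kQInt8);
//   auto packed = PackedLinearWeightQnnp::prepack(
//       qweight,
//       bias_fp32, // std::optional; zero-filled internally when absent
//       /*out_features_block_size=*/8,
//       /*in_features_block_size=*/1);
//   // Only the dynamic entry points are implemented for sparse QNNPACK;
//   // apply()/apply_relu() above unconditionally TORCH_CHECK(false).
//   at::Tensor out = packed->apply_dynamic(input_fp32);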

}}  // namespace ao::sparse

#endif // USE_PYTORCH_QNNPACK