#pragma once

#include <ATen/Tensor.h>
#include <c10/core/QScheme.h>

#ifdef USE_PYTORCH_QNNPACK
// TODO: Refactor QnnpackUtils.h so as to separate the code
// needed for quantized ops from the generic QNNPACK-specific
// quantization utilities.
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <pack_block_sparse.h>
#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>

namespace ao {
namespace sparse {

struct TORCH_API PackedLinearWeightQnnp
    : public LinearPackedParamsBase {
  PackedLinearWeightQnnp(
      const at::Tensor& weight,
      const std::optional<at::Tensor>& bias,
      const int64_t out_features_block_size /* block sparsity size across output_features */,
      const int64_t in_features_block_size /* block sparsity size across input_features */);
  explicit PackedLinearWeightQnnp(const BCSRSerializationType& serialized);
  std::optional<at::Tensor> orig_bias_;
  // A separate copy of the bias exists so that we can fill in zeros when
  // the optional bias is absent. This is to comply with the QNNPACK
  // operator, which expects the bias to be present.
  // When a bias is present, bias_ is just a reference to orig_bias_.
  at::Tensor bias_;
  c10::QScheme q_scheme_;
  double input_scale_;
  std::unique_ptr<qnnpack::BCSRMatrix> bcsr_matrix_;
  at::Tensor w_scales_;
  std::vector<uint8_t> w_zero_points_;
  std::vector<float> requantization_scales_;
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      sparse_linear_op_{nullptr};
  int64_t output_channels_;
  int64_t input_channels_;
  // Deserialized tensors are stored to maintain the lifetime of the
  // underlying BCSR data.
  // These are left empty if PackedLinearWeightQnnp is created via prepacking
  // rather than deserialization.
  at::Tensor deserialized_bcsr_row_block_indices_;
  at::Tensor deserialized_bcsr_col_block_indices_;
  at::Tensor deserialized_bcsr_weight_values_;

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override {
    TORCH_CHECK(
        false, "Static quantized sparse linear unimplemented on QNNPACK");
  }

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override {
    TORCH_CHECK(
        false, "Static quantized sparse linear unimplemented on QNNPACK");
  }

  at::Tensor apply_dynamic(const at::Tensor& input) override;
  at::Tensor apply_dynamic_relu(const at::Tensor& input) override;

  LinearPackedSerializationType unpack() override;

  BCSRSerializationType serialize() override;

  static c10::intrusive_ptr<LinearPackedParamsBase> deserialize(
      const BCSRSerializationType& serialized);

  std::optional<at::Tensor> bias() override {
    return orig_bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      const at::Tensor& weight,
      const std::optional<at::Tensor>& bias,
      const int64_t out_features_block_size,
      const int64_t in_features_block_size);

 private:
  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(const at::Tensor& input);
};

}} // namespace ao::sparse

#endif // USE_PYTORCH_QNNPACK
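
// A minimal usage sketch (not part of this header): how the prepack /
// apply_dynamic pair declared above is typically driven. The weight shape,
// quantization parameters, block sizes, and the helper function name
// run_sparse_linear below are illustrative assumptions, not values mandated
// by this API. Note that only the dynamically quantized path is implemented
// on QNNPACK; apply() and apply_relu() TORCH_CHECK-fail above.
//
//   #include <ATen/ATen.h>
//
//   at::Tensor run_sparse_linear(const at::Tensor& float_input) {
//     // Quantize a dense weight matrix per tensor; scale/zero_point are
//     // arbitrary illustrative values.
//     at::Tensor qweight = at::quantize_per_tensor(
//         at::randn({64, 128}), /*scale=*/0.1, /*zero_point=*/0, at::kQInt8);
//     // Pack with a 1x4 block-sparsity pattern (out_features_block_size=1,
//     // in_features_block_size=4), one of the block shapes QNNPACK's sparse
//     // kernels support. Zeroed 1x4 blocks in qweight are pruned into the
//     // BCSR representation held by bcsr_matrix_.
//     auto packed = ao::sparse::PackedLinearWeightQnnp::prepack(
//         qweight,
//         /*bias=*/std::nullopt, // bias_ is then zero-filled internally
//         /*out_features_block_size=*/1,
//         /*in_features_block_size=*/4);
//     // Dynamic quantization: the fp32 input is quantized on the fly and
//     // the output is returned as fp32.
//     return packed->apply_dynamic(float_input);
//   }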