xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ao_sparse/quantized/cpu/packed_params.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstdint>
4 
5 #include <ATen/core/ivalue.h>
6 
7 namespace ao {
8 namespace sparse {
9 
10 // <Weight, bias, out_features_block_size, in_features_block_size>
11 using LinearPackedSerializationType =
12     std::tuple<at::Tensor, std::optional<at::Tensor>, std::vector<int64_t>>;
13 
14 #define SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION 2
15 
16 using BCSRSerializationType =
17     std::tuple<
18         int64_t,                    // Serialization Version
19         std::optional<at::Tensor>,  // Bias
20         int64_t,                    // Out Features (Row) Block Size
21         int64_t,                    // In Features (Column) Block Size
22         at::Tensor,                 // Weight Scales (single element vector if per-tensor) (float)
23         at::Tensor,                 // Wrapper for Weight Zero Points (single element vector if per-tensor) (int8_t)
24         bool,                       // Quantization Scheme (true: per tensor, false: per channel)
25         at::Tensor,                 // Wrapper for Row Block Indices (int8_t, int16_t, or int32_t)
26         at::Tensor,                 // Wrapper for Column Block Indices (int8_t, int16_t, or int32_t)
27         at::Tensor,                 // Wrapper for Non-Zero Weight Values, each +128 (uint8_t)
28         int64_t,                    // Number of Output Channels
29         int64_t                     // Number of Input Channels
30     >;
31 
// Three parallel arrays, in order:
//   0: non-zero weight values (int8_t)
//   1: compressed row block indices (int32_t)
//   2: column block indices (int32_t)
using BCSR = std::tuple<
    std::vector<int8_t>,
    std::vector<int32_t>,
    std::vector<int32_t>>;
38 
39 struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
40  public:
LinearPackedParamsBaseLinearPackedParamsBase41   LinearPackedParamsBase(
42       const int64_t out_features_block_size,
43       const int64_t in_features_block_size)
44       : out_features_block_size_(out_features_block_size),
45         in_features_block_size_(in_features_block_size) {}
46 
47   virtual at::Tensor apply(
48       const at::Tensor& input,
49       double output_scale,
50       int64_t output_zero_point) = 0;
51   virtual at::Tensor apply_relu(
52       const at::Tensor& input,
53       double output_scale,
54       int64_t output_zero_point) = 0;
55 
56   virtual at::Tensor apply_dynamic(const at::Tensor& input) = 0;
57   virtual at::Tensor apply_dynamic_relu(const at::Tensor& input) = 0;
58 
59   virtual LinearPackedSerializationType unpack() = 0;
60 
61   virtual BCSRSerializationType serialize() = 0;
62 
63   virtual std::optional<at::Tensor> bias() = 0;
64 
set_biasLinearPackedParamsBase65   virtual void set_bias(const std::optional<at::Tensor>& bias) {
66     throw std::runtime_error(
67         "set_bias is not implemented for this packed "
68         "parameter type");
69   }
70 
71  protected:
72   const int64_t out_features_block_size_, in_features_block_size_;
73 };
74 
75 }}  // namespace ao::sparse
76