1 #pragma once 2 3 #include <cstdint> 4 5 #include <ATen/core/ivalue.h> 6 7 namespace ao { 8 namespace sparse { 9 10 // <Weight, bias, out_features_block_size, in_features_block_size> 11 using LinearPackedSerializationType = 12 std::tuple<at::Tensor, std::optional<at::Tensor>, std::vector<int64_t>>; 13 14 #define SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION 2 15 16 using BCSRSerializationType = 17 std::tuple< 18 int64_t, // Serialization Version 19 std::optional<at::Tensor>, // Bias 20 int64_t, // Out Features (Row) Block Size 21 int64_t, // In Features (Column) Block Size 22 at::Tensor, // Weight Scales (single element vector if per-tensor) (float) 23 at::Tensor, // Wrapper for Weight Zero Points (single element vector if per-tensor) (int8_t) 24 bool, // Quantization Scheme (true: per tensor, false: per channel) 25 at::Tensor, // Wrapper for Row Block Indices (int8_t, int16_t, or int32_t) 26 at::Tensor, // Wrapper for Column Block Indices (int8_t, int16_t, or int32_t) 27 at::Tensor, // Wrapper for Non-Zero Weight Values, each +128 (uint8_t) 28 int64_t, // Number of Output Channels 29 int64_t // Number of Input Channels 30 >; 31 32 using BCSR = 33 std::tuple< 34 std::vector<int8_t>, // Non-Zero Weight Values 35 std::vector<int32_t>, // Compressed Row Block Indices 36 std::vector<int32_t> // Column Block Indices 37 >; 38 39 struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { 40 public: LinearPackedParamsBaseLinearPackedParamsBase41 LinearPackedParamsBase( 42 const int64_t out_features_block_size, 43 const int64_t in_features_block_size) 44 : out_features_block_size_(out_features_block_size), 45 in_features_block_size_(in_features_block_size) {} 46 47 virtual at::Tensor apply( 48 const at::Tensor& input, 49 double output_scale, 50 int64_t output_zero_point) = 0; 51 virtual at::Tensor apply_relu( 52 const at::Tensor& input, 53 double output_scale, 54 int64_t output_zero_point) = 0; 55 56 virtual at::Tensor apply_dynamic(const at::Tensor& input) = 0; 57 virtual at::Tensor apply_dynamic_relu(const at::Tensor& input) = 0; 58 59 virtual LinearPackedSerializationType unpack() = 0; 60 61 virtual BCSRSerializationType serialize() = 0; 62 63 virtual std::optional<at::Tensor> bias() = 0; 64 set_biasLinearPackedParamsBase65 virtual void set_bias(const std::optional<at::Tensor>& bias) { 66 throw std::runtime_error( 67 "set_bias is not implemented for this packed " 68 "parameter type"); 69 } 70 71 protected: 72 const int64_t out_features_block_size_, in_features_block_size_; 73 }; 74 75 }} // namespace ao::sparse 76