#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <torch/custom_class.h>

#include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
#include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_per_channel_affine_quantized.h>
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/from_blob.h>
#endif

namespace ao {
namespace sparse {
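// Forward declaration; registers the LinearPackedParamsBase custom class so
// packed sparse linear weights can be serialized and deserialized. The
// definition lives with the packed-params implementation in this directory.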
int register_linear_params();

#ifdef USE_FBGEMM

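// Rebuilds the dense qint8 weight tensor from FBGEMM's block-sparse (BCSR)
// packed representation and returns it together with the optional bias and
// the {out_features, in_features} block pattern.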
LinearPackedSerializationType PackedLinearWeight::unpack() {
  auto packW = w.get();

  const int64_t N = static_cast<int64_t>(packW->R);
  const int64_t K = static_cast<int64_t>(packW->C);

  at::Tensor weight_origin;
  if (q_scheme == c10::kPerTensorAffine) {
    weight_origin = at::_empty_affine_quantized(
        {N, K}, at::device(c10::kCPU).dtype(c10::kQInt8), w_scale[0], w_zp[0]);
  } else if (q_scheme == c10::kPerChannelAffine) {
    at::Tensor scales = at::empty(
        {static_cast<long>(w_scale.size())},
        at::device(c10::kCPU).dtype(c10::kFloat));
    std::copy(w_scale.begin(), w_scale.end(), scales.mutable_data_ptr<float>());

    at::Tensor zero_points = at::empty(
        {static_cast<long>(w_zp.size())},
        at::device(c10::kCPU).dtype(c10::kInt));
    std::copy(w_zp.begin(), w_zp.end(), zero_points.mutable_data_ptr<int>());

    weight_origin = at::_empty_per_channel_affine_quantized(
        {N, K},
        scales,
        zero_points,
        0, // The output channel axis is 0
        at::device(c10::kCPU).dtype(c10::kQInt8));
  } else {
    TORCH_CHECK(false, "Unsupported quantization scheme.");
  }

  int8_t* weight_ptr_int8 =
      reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>());

  packW->unpack(weight_ptr_int8);

  const std::vector<int64_t> block_pattern(
      {out_features_block_size_, in_features_block_size_});

  return std::make_tuple(std::move(weight_origin), bias_, block_pattern);
}

#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK

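// QNNPACK counterpart of unpack(). QNNPACK keeps weights and zero points as
// uint8 with a +128 offset, so both the per-tensor and per-channel paths
// below shift zero points back into the signed qint8 domain.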
LinearPackedSerializationType PackedLinearWeightQnnp::unpack() {
  const int64_t N = static_cast<int64_t>(output_channels_);
  const int64_t K = static_cast<int64_t>(input_channels_);

  float* w_scales_ptr = w_scales_.data_ptr<float>();

  at::Tensor weight_origin;
  if (q_scheme_ == c10::kPerTensorAffine) {
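    // Shift the uint8 zero point back into the signed qint8 domain.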
    weight_origin = at::_empty_affine_quantized(
        {N, K},
        at::device(c10::kCPU).dtype(c10::kQInt8),
        w_scales_ptr[0],
        w_zero_points_[0] - 128);
  } else if (q_scheme_ == c10::kPerChannelAffine) {
    at::Tensor scales = at::empty(
        {static_cast<long>(output_channels_)},
        at::device(c10::kCPU).dtype(c10::kFloat));
    std::copy(
        w_scales_ptr,
        w_scales_ptr + output_channels_,
        scales.mutable_data_ptr<float>());

    at::Tensor zero_points = at::empty(
        {static_cast<long>(output_channels_)},
        at::device(c10::kCPU).dtype(c10::kInt));
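    // Same shift as above, applied to each per-channel zero point.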
    std::transform(
        w_zero_points_.begin(),
        w_zero_points_.begin() + output_channels_,
        zero_points.mutable_data_ptr<int>(),
        [](uint8_t v) { return static_cast<int>(v) - 128; });

    weight_origin = at::_empty_per_channel_affine_quantized(
        {N, K},
        scales,
        zero_points,
        0, // The output channel axis is 0
        at::device(c10::kCPU).dtype(c10::kQInt8));
  } else {
    TORCH_CHECK(false, "Unsupported quantization scheme.");
  }

  int8_t* weight_ptr_int8 =
      reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>());

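  // Expand the block-sparse (BCSR) matrix into the dense row-major buffer;
  // positions not covered by a stored block decode to the channel's zero
  // point, i.e. a dequantized value of zero.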
  bcsr_matrix_->unpack(
      weight_ptr_int8,
      output_channels_,
      input_channels_,
      w_zero_points_.data());

  std::vector<int64_t> block_pattern(
      {out_features_block_size_, in_features_block_size_});

  return std::make_tuple(
      std::move(weight_origin), bias_, std::move(block_pattern));
}

#endif // USE_PYTORCH_QNNPACK

namespace {

class QLinearUnpackWeightInt8 final {
 public:
  static LinearPackedSerializationType run(
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
    return packed_weight->unpack();
  }
};

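// Unpacking only reads the packed-params object, so the op is registered
// under the CatchAll dispatch key rather than a per-backend one. A minimal
// usage sketch from Python, assuming a weight prepacked with
// sparse::qlinear_prepack:
//
//   w, b, block_pattern = torch.ops.sparse.qlinear_unpack(packed_weight)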
TORCH_LIBRARY_IMPL(sparse, CatchAll, m) {
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear_unpack"),
      TORCH_FN(QLinearUnpackWeightInt8::run));
}
} // namespace
}} // namespace ao::sparse