xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <torch/custom_class.h>
4 
5 #include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
6 #include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
7 #include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h>
8 
9 #ifndef AT_PER_OPERATOR_HEADERS
10 #include <ATen/Functions.h>
11 #else
12 #include <ATen/ops/_empty_per_channel_affine_quantized.h>
13 #include <ATen/ops/_empty_affine_quantized.h>
14 #include <ATen/ops/empty.h>
15 #include <ATen/ops/from_blob.h>
16 #endif
17 
18 namespace ao {
19 namespace sparse {
20 int register_linear_params();
21 
22 #ifdef USE_FBGEMM
23 
unpack()24 LinearPackedSerializationType PackedLinearWeight::unpack() {
25   auto packW = w.get();
26 
27   const int64_t N = static_cast<int64_t>(packW->R);
28   const int64_t K = static_cast<int64_t>(packW->C);
29 
30   at::Tensor weight_origin;
31   if (q_scheme == c10::kPerTensorAffine) {
32     weight_origin = at::_empty_affine_quantized(
33         {N, K}, at::device(c10::kCPU).dtype(c10::kQInt8), w_scale[0], w_zp[0]);
34   } else if (q_scheme == c10::kPerChannelAffine) {
35     at::Tensor scales = at::empty(
36         {static_cast<long>(w_scale.size())},
37         at::device(c10::kCPU).dtype(c10::kFloat));
38     std::copy(w_scale.begin(), w_scale.end(), scales.mutable_data_ptr<float>());
39 
40     at::Tensor zero_points = at::empty(
41         {static_cast<long>(w_zp.size())},
42         at::device(c10::kCPU).dtype(c10::kInt));
43     std::copy(w_zp.begin(), w_zp.end(), zero_points.mutable_data_ptr<int>());
44 
45     weight_origin = at::_empty_per_channel_affine_quantized(
46         {N, K},
47         scales,
48         zero_points,
49         0, // The output channel axis is 0
50         device(c10::kCPU).dtype(c10::kQInt8));
51   }
52 
53   int8_t* weight_ptr_int8 =
54       reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>());
55 
56   packW->unpack(weight_ptr_int8);
57 
58   const std::vector<int64_t> block_pattern(
59       {out_features_block_size_, in_features_block_size_});
60 
61   return std::make_tuple(std::move(weight_origin), bias_, block_pattern);
62 }
63 
64 #endif // USE_FBGEMM
65 
66 #ifdef USE_PYTORCH_QNNPACK
67 
unpack()68 LinearPackedSerializationType PackedLinearWeightQnnp::unpack() {
69   const int64_t N = static_cast<int64_t>(output_channels_);
70   const int64_t K = static_cast<int64_t>(input_channels_);
71 
72   float* w_scales_ptr = w_scales_.data_ptr<float>();
73 
74   at::Tensor weight_origin;
75   if (q_scheme_ == c10::kPerTensorAffine) {
76     weight_origin = at::_empty_affine_quantized(
77         {N, K},
78         at::device(c10::kCPU).dtype(c10::kQInt8),
79         w_scales_ptr[0],
80         w_zero_points_[0] - 128);
81   } else if (q_scheme_ == c10::kPerChannelAffine) {
82     at::Tensor scales = at::empty(
83         {static_cast<long>(output_channels_)},
84         at::device(c10::kCPU).dtype(c10::kFloat));
85     std::copy(
86         w_scales_ptr,
87         w_scales_ptr + output_channels_,
88         scales.mutable_data_ptr<float>());
89 
90     at::Tensor zero_points = at::empty(
91         {static_cast<long>(output_channels_)},
92         at::device(c10::kCPU).dtype(c10::kInt));
93     std::transform(
94         w_zero_points_.begin(),
95         w_zero_points_.begin() + output_channels_,
96         zero_points.mutable_data_ptr<int>(),
97         [](uint8_t v) { return static_cast<int>(v) - 128; });
98 
99     weight_origin = at::_empty_per_channel_affine_quantized(
100         {N, K},
101         scales,
102         zero_points,
103         0, // The output channel axis is 0
104         device(c10::kCPU).dtype(c10::kQInt8));
105   }
106 
107   int8_t* weight_ptr_int8 =
108       reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>());
109 
110   bcsr_matrix_->unpack(
111       weight_ptr_int8,
112       output_channels_,
113       input_channels_,
114       w_zero_points_.data());
115 
116   std::vector<int64_t> block_pattern(
117       {out_features_block_size_, in_features_block_size_});
118 
119   return std::make_tuple(
120       std::move(weight_origin), bias_, std::move(block_pattern));
121 }
122 
#endif // USE_PYTORCH_QNNPACK
124 
125 namespace {
126 
// Thin dispatch wrapper registered as the kernel for sparse::qlinear_unpack:
// forwards to the packed-params object's virtual unpack(), which yields the
// serialization tuple (dense weight tensor, bias, block pattern).
class QLinearUnpackWeightInt8 final {
 public:
  static LinearPackedSerializationType run(
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
    return packed_weight->unpack();
  }
};
134 
// Register the unpack kernel for the `sparse` library under the CatchAll
// dispatch key (the op does not dispatch on tensor arguments).
// register_linear_params() is called first so the custom-class serialization
// for packed linear params is available before the op can be used.
TORCH_LIBRARY_IMPL(sparse, CatchAll, m) {
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear_unpack"),
      TORCH_FN(QLinearUnpackWeightInt8::run));
}
141 }  // namespace
142 }}  // namespace ao::sparse
143