// aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <torch/custom_class.h>
#include <torch/library.h>

#include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/empty.h>
#endif

namespace ao {
namespace sparse {

int register_linear_params();

#ifdef USE_FBGEMM

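// Shared implementation behind apply() and apply_relu(): runs a sparse-weight
// x dense-activation int8 GEMM via FBGEMM on a quint8 input and requantizes
// the int32 accumulators to quint8 using the given output scale / zero point.
// ReluFused selects whether a ReLU is fused into the requantization step.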
template <bool ReluFused>
at::Tensor PackedLinearWeight::apply_impl(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  // uint8 * int8 -> uint8 (no quantization/dequantization)

  // We make a strong guarantee that models using these operators will have
  // the same numerics across different machines. Therefore, we do not provide
  // a fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");

  // TODO: contiguous is called for further jit optimizations.
  auto input_contig = input.contiguous();
  const auto* input_ptr =
      reinterpret_cast<uint8_t*>(input_contig.data_ptr<c10::quint8>());

  TORCH_CHECK(
      input.dim() >= 2,
      "The dimension of the input tensor should be greater than or equal to 2");
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  int64_t batch_size = size_to_dim_(input.dim() - 1, input.sizes());
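  // size_to_dim_(input.dim() - 1, ...) collapses all leading dimensions into
  // a single batch dimension, so an input of shape {x, y, K} is treated as an
  // (x * y) x K matrix for the GEMM below.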

  auto packW = w.get();

  int64_t out_channels = static_cast<int64_t>(packW->R);
  int64_t K = input.size(input.dim() - 1);
  TORCH_CHECK(
      K == static_cast<int64_t>(packW->C),
      "The number of columns in the packW should be equal to K: " +
          std::to_string(K));

  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  float input_scale_float = input.q_scale();
  int32_t input_zero_point_int32 = input.q_zero_point();

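  // Requantization scales: the int32 GEMM accumulator represents a real value
  // scaled by (input_scale * weight_scale), so mapping it into the output
  // quantized domain uses input_scale * weight_scale / output_scale, either as
  // a single multiplier (per-tensor) or one per output channel (per-channel).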
  std::vector<float> output_multiplier_float(1, 0.0);
  std::vector<float> act_times_w_scale(1, 0.0);
  TORCH_CHECK(
      w_scale.size() == w_zp.size(),
      "Weight scales and zero points vectors should have the same size.");
  if (q_scheme == c10::kPerTensorAffine) {
    // Process the per tensor quantization.
    act_times_w_scale[0] = (input_scale_float * w_scale[0]);
    output_multiplier_float[0] =
        act_times_w_scale[0] / static_cast<float>(output_scale);
  } else if (q_scheme == c10::kPerChannelAffine) {
    // Process the per channel quantization.
    output_multiplier_float.resize(out_channels, 0.0);
    act_times_w_scale.resize(out_channels, 1.0f);
    for (const auto i : c10::irange(out_channels)) {
      act_times_w_scale[i] = (input_scale_float * w_scale[i]);
      output_multiplier_float[i] =
          act_times_w_scale[i] / static_cast<float>(output_scale);
    }
  }
  int32_t output_zero_point_int32 = static_cast<int32_t>(output_zero_point);

  const float* bias_ptr = nullptr;
  at::Tensor bias;
  if (this->bias_.has_value()) {
    bias = this->bias_.value();
    bias = bias.contiguous();
    TORCH_CHECK(bias.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias.size(0) == out_channels,
        "bias should have out_channels elements: " +
            std::to_string(out_channels));
    bias_ptr = reinterpret_cast<float*>(bias.data_ptr<float>());
  }

  // The resulting matrix here is 2-D, let's view it with the original
  // left hand dimensions of the input. Here are two examples:
  // 1. If the input tensor is {batch_size, K}, the output tensor is
  // {batch_size, out_channels}.
  // 2. If the input tensor is {x, batch_size, K}, the output tensor is {x,
  // batch_size, out_channels}.
  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = out_channels; // NOLINT
  // Allocate the output Tensors and an int32 buffer for the FBGEMM kernel to use
  auto output_tr = at::_empty_affine_quantized(
      out_sizes,
      at::device(c10::kCPU).dtype(c10::kQUInt8),
      output_scale,
      output_zero_point);
  auto output = at::_empty_affine_quantized(
      out_sizes,
      at::device(c10::kCPU).dtype(c10::kQUInt8),
      output_scale,
      output_zero_point);

  auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt));
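  // Note: output_tr receives the kernel result in transposed (out_channels x
  // batch_size) layout and is transposed back into `output` at the very end,
  // while `buffer` holds the raw int32 accumulators before requantization.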

  // fbgemm kernel computes the following:
  // C(output) = A(weight) x B(input), where C, A, B are out_channels x
  // batch_size, out_channels x K, K x batch_size matrices, respectively.
  // Therefore we need to transpose input
  auto input_tr = at::_empty_affine_quantized(
      input.sizes(),
      at::device(c10::kCPU).dtype(c10::kQUInt8),
      input_scale_float,
      input_zero_point_int32);

  auto* input_tr_ptr =
      reinterpret_cast<uint8_t*>(input_tr.data_ptr<c10::quint8>());
  // TODO: Activation transposes before and after the kernel can be removed if
  // we keep the activation tensor always transposed.
  fbgemm::transpose_simd<uint8_t>(
      batch_size, K, input_ptr, K, input_tr_ptr, batch_size);

  int num_tasks = at::get_num_threads();
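  // One task per intra-op thread; each task passes its task_id into the FBGEMM
  // kernel, which partitions the work internally across num_threads.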
  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
    for (const auto task_id : c10::irange(begin, end)) {
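      // Fused requantization parameters: zero points, per-channel weight
      // column offsets, optional bias, and the activation*weight scales
      // computed above (the "tr" prefix presumably refers to the transposed
      // layout used by this kernel).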
      fbgemm::trRequantizationParams_t reqParams = {
          input_zero_point_int32,
          w_zp.data(),
          output_zero_point_int32,
          static_cast<float>(output_scale),
          col_offsets.data(),
          /*activation offsets*/ nullptr,
          bias_ptr,
          act_times_w_scale.data()};

      if (q_scheme == c10::kPerTensorAffine) {
        // Process the per tensor quantization.
        //
        // After the uint8 * int8 matrix multiplication is performed, this
        // operation does:
        //  1) Add in row and column offsets to the rows and columns,
        //  respectively.
        //  2) Add in the bias term.

        // Do the GEMM
        fbgemm::fbgemmSparseDenseInt8MM<
            ReluFused,
            fbgemm::QuantizationGranularity::TENSOR>(
            batch_size,
            w,
            input_tr_ptr,
            /*ldb=*/batch_size,
            /*C_i32=*/buffer.data_ptr<int32_t>(),
            /*C_u8=*/
            reinterpret_cast<uint8_t*>(output_tr.data_ptr<c10::quint8>()),
            /*ldc=*/batch_size,
            /*rParams=*/reqParams,
            /*accum=*/false,
            /*thread_id=*/task_id,
            /*num_threads=*/num_tasks);
      } else if (q_scheme == c10::kPerChannelAffine) {
        // Process the per channel quantization.
        //
        // After the uint8 * int8 matrix multiplication is performed, this
        // operation does:
        //  1) Add in row and column offsets to the rows and columns,
        //  respectively.
        //  2) Add in the bias term.

        // Do the GEMM
        fbgemm::fbgemmSparseDenseInt8MM<
            ReluFused,
            fbgemm::QuantizationGranularity::OUT_CHANNEL>(
            batch_size,
            w,
            input_tr_ptr,
            /*ldb=*/batch_size,
            /*C_i32=*/buffer.data_ptr<int32_t>(),
            /*C_u8=*/
            reinterpret_cast<uint8_t*>(output_tr.data_ptr<c10::quint8>()),
            /*ldc=*/batch_size,
            /*rParams=*/reqParams,
            /*accum=*/false,
            /*thread_id=*/task_id,
            /*num_threads=*/num_tasks);
      }
    }
  });

  // Transpose output_tr back to batch_size x out_channels.
  fbgemm::transpose_simd<uint8_t>(
      out_channels,
      batch_size,
      reinterpret_cast<uint8_t*>(output_tr.data_ptr<c10::quint8>()),
      batch_size,
      reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
      out_channels);

  return output;
}

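// Public entry points: apply() runs the plain quantized linear op, while
// apply_relu() runs the same kernel with the ReLU fused into the
// requantization (ReluFused = true).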
at::Tensor PackedLinearWeight::apply(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  return apply_impl<false>(input, output_scale, output_zero_point);
}

at::Tensor PackedLinearWeight::apply_relu(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  return apply_impl<true>(input, output_scale, output_zero_point);
}

#endif // USE_FBGEMM

namespace {

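// Thin dispatcher that forwards to apply()/apply_relu() on the packed weight
// object; registered below as the sparse::qlinear(_relu) operators.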
template <bool ReluFused>
class QLinearInt8 final {
 public:
  static at::Tensor run(
      const at::Tensor& input,
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
      double output_scale,
      int64_t output_zero_point) {
    if (ReluFused) {
      return packed_weight->apply_relu(input, output_scale, output_zero_point);
    } else {
      return packed_weight->apply(input, output_scale, output_zero_point);
    }
  }
};

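// Register the QuantizedCPU kernels for the sparse::qlinear and
// sparse::qlinear_relu ops. register_linear_params() presumably ensures the
// packed-parameter custom class is registered before these kernels are used.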
TORCH_LIBRARY_IMPL(sparse, QuantizedCPU, m) {
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear"),
      TORCH_FN(QLinearInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear_relu"),
      TORCH_FN(QLinearInt8<true>::run));
}
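
// Rough call sketch from Python (hypothetical usage; assumes the weight was
// packed with the matching ao::sparse prepack op):
//   out = torch.ops.sparse.qlinear(q_input, packed_weight, out_scale, out_zp)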

} // namespace
}} // namespace ao::sparse