#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <torch/custom_class.h>
#include <torch/library.h>

#include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/empty.h>
#endif

namespace ao {
namespace sparse {

int register_linear_params();

#ifdef USE_FBGEMM

template <bool ReluFused>
at::Tensor PackedLinearWeight::apply_impl(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  // uint8 * int8 -> uint8 (no quantization/dequantization)

  // We make a strong guarantee that models using these operators will have
  // the same numerics across different machines. Therefore, we do not provide
  // a fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");

  // TODO: contiguous is called for further jit optimizations.
  auto input_contig = input.contiguous();
  const auto* input_ptr =
      reinterpret_cast<uint8_t*>(input_contig.data_ptr<c10::quint8>());

  TORCH_CHECK(
      input.dim() >= 2,
      "The dimension of input tensor should be larger than or equal to 2");
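  // All leading dimensions are flattened into a single "batch" dimension, so a
  // {d0, d1, ..., K} input is handled as a (d0 * d1 * ...) x K matrix.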
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  int64_t batch_size = size_to_dim_(input.dim() - 1, input.sizes());

  auto packW = w.get();

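  // The packed sparse weight is an out_channels x K matrix: R is its number of
  // rows (output channels) and C its number of columns (input features).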
  int64_t out_channels = static_cast<int64_t>(packW->R);
  int64_t K = input.size(input.dim() - 1);
  TORCH_CHECK(
      K == static_cast<int64_t>(packW->C),
      "The number of columns in the packW should be equal to K: " +
          std::to_string(K));

  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  float input_scale_float = input.q_scale();
  int32_t input_zero_point_int32 = input.q_zero_point();

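  // The effective requantization multiplier per output channel is
  // (input_scale * weight_scale) / output_scale; it maps the int32 accumulator
  // back into the uint8 output domain.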
  std::vector<float> output_multiplier_float(1, 0.0);
  std::vector<float> act_times_w_scale(1, 0.0);
  TORCH_CHECK(
      w_scale.size() == w_zp.size(),
      "Weight scales and zero points vectors should have the same size.");
  if (q_scheme == c10::kPerTensorAffine) {
    // Process the per tensor quantization.
    act_times_w_scale[0] = (input_scale_float * w_scale[0]);
    output_multiplier_float[0] =
        act_times_w_scale[0] / static_cast<float>(output_scale);
  } else if (q_scheme == c10::kPerChannelAffine) {
    // Process the per channel quantization.
    output_multiplier_float.resize(out_channels, 0.0);
    act_times_w_scale.resize(out_channels, 1.0f);
    for (const auto i : c10::irange(out_channels)) {
      act_times_w_scale[i] = (input_scale_float * w_scale[i]);
      output_multiplier_float[i] =
          act_times_w_scale[i] / static_cast<float>(output_scale);
    }
  }
  int32_t output_zero_point_int32 = static_cast<int32_t>(output_zero_point);

  const float* bias_ptr = nullptr;
  at::Tensor bias;
  if (this->bias_.has_value()) {
    bias = this->bias_.value();
    bias = bias.contiguous();
    TORCH_CHECK(bias.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias.size(0) == out_channels,
        "bias should have out_channels elements: " +
            std::to_string(out_channels));
    bias_ptr = reinterpret_cast<float*>(bias.data_ptr<float>());
  }

  // The resulting matrix here is 2-D; let's view it with the original
  // left-hand dimensions of the input. Two examples:
  // 1. If the input tensor is {batch_size, K}, the output tensor is
  //    {batch_size, out_channels}.
  // 2. If the input tensor is {x, batch_size, K}, the output tensor is
  //    {x, batch_size, out_channels}.
  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = out_channels; // NOLINT
  // Allocate the output Tensor and a buffer for fbgemm to use
  auto output_tr = at::_empty_affine_quantized(
      out_sizes,
      at::device(c10::kCPU).dtype(c10::kQUInt8),
      output_scale,
      output_zero_point);
  auto output = at::_empty_affine_quantized(
      out_sizes,
      at::device(c10::kCPU).dtype(c10::kQUInt8),
      output_scale,
      output_zero_point);

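  // buffer holds the raw int32 accumulators produced by the sparse GEMM before
  // requantization writes the uint8 result into output_tr.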
  auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt));

  // The fbgemm kernel computes the following:
  // C(output) = A(weight) x B(input), where C, A, B are out_channels x
  // batch_size, out_channels x K, and K x batch_size matrices, respectively.
  // Therefore we need to transpose the input.
  auto input_tr = at::_empty_affine_quantized(
      input.sizes(),
      at::device(c10::kCPU).dtype(c10::kQUInt8),
      input_scale_float,
      input_zero_point_int32);

  auto* input_tr_ptr =
      reinterpret_cast<uint8_t*>(input_tr.data_ptr<c10::quint8>());
  // TODO: The activation transpose before and after the kernel can be removed
  // if we keep the activation tensor always transposed.
  fbgemm::transpose_simd<uint8_t>(
      batch_size, K, input_ptr, K, input_tr_ptr, batch_size);

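  // One task per worker thread; the thread_id / num_threads arguments let the
  // fbgemm kernel partition the GEMM across tasks.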
  int num_tasks = at::get_num_threads();
  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
    for (const auto task_id : c10::irange(begin, end)) {
      fbgemm::trRequantizationParams_t reqParams = {
          input_zero_point_int32,
          w_zp.data(),
          output_zero_point_int32,
          static_cast<float>(output_scale),
          col_offsets.data(),
          /*activation offsets*/ nullptr,
          bias_ptr,
          act_times_w_scale.data()};

      if (q_scheme == c10::kPerTensorAffine) {
        // Process the per tensor quantization.
        //
        // After the uint8 * int8 matrix multiplication is performed, this
        // operation does:
        //  1) Add in row and column offsets to the rows and columns,
        //     respectively.
        //  2) Add in the bias term.

        // Do the GEMM
        fbgemm::fbgemmSparseDenseInt8MM<
            ReluFused,
            fbgemm::QuantizationGranularity::TENSOR>(
            batch_size,
            w,
            input_tr_ptr,
            /*ldb=*/batch_size,
            /*C_i32=*/buffer.data_ptr<int32_t>(),
            /*C_u8=*/
            reinterpret_cast<uint8_t*>(output_tr.data_ptr<c10::quint8>()),
            /*ldc=*/batch_size,
            /*rParams=*/reqParams,
            /*accum=*/false,
            /*thread_id=*/task_id,
            /*num_threads=*/num_tasks);
      } else if (q_scheme == c10::kPerChannelAffine) {
        // Process the per channel quantization.
        //
        // After the uint8 * int8 matrix multiplication is performed, this
        // operation does:
        //  1) Add in row and column offsets to the rows and columns,
        //     respectively.
        //  2) Add in the bias term.

        // Do the GEMM
        fbgemm::fbgemmSparseDenseInt8MM<
            ReluFused,
            fbgemm::QuantizationGranularity::OUT_CHANNEL>(
            batch_size,
            w,
            input_tr_ptr,
            /*ldb=*/batch_size,
            /*C_i32=*/buffer.data_ptr<int32_t>(),
            /*C_u8=*/
            reinterpret_cast<uint8_t*>(output_tr.data_ptr<c10::quint8>()),
            /*ldc=*/batch_size,
            /*rParams=*/reqParams,
            /*accum=*/false,
            /*thread_id=*/task_id,
            /*num_threads=*/num_tasks);
      }
    }
  });

  // Transpose output_tr back to batch_size x out_channels.
  fbgemm::transpose_simd<uint8_t>(
      out_channels,
      batch_size,
      reinterpret_cast<uint8_t*>(output_tr.data_ptr<c10::quint8>()),
      batch_size,
      reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
      out_channels);

  return output;
}

at::Tensor PackedLinearWeight::apply(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  return apply_impl<false>(input, output_scale, output_zero_point);
}

at::Tensor PackedLinearWeight::apply_relu(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  return apply_impl<true>(input, output_scale, output_zero_point);
}

#endif // USE_FBGEMM

namespace {

template <bool ReluFused>
class QLinearInt8 final {
 public:
  static at::Tensor run(
      const at::Tensor& input,
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
      double output_scale,
      int64_t output_zero_point) {
    if (ReluFused) {
      return packed_weight->apply_relu(input, output_scale, output_zero_point);
    } else {
      return packed_weight->apply(input, output_scale, output_zero_point);
    }
  }
};

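// Register the QuantizedCPU kernels backing sparse::qlinear and
// sparse::qlinear_relu; once registered they are reachable from Python as
// torch.ops.sparse.qlinear(...) and torch.ops.sparse.qlinear_relu(...).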
TORCH_LIBRARY_IMPL(sparse, QuantizedCPU, m) {
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear"),
      TORCH_FN(QLinearInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("sparse::qlinear_relu"),
      TORCH_FN(QLinearInt8<true>::run));
}

} // namespace
}} // namespace ao::sparse