#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Utils.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/conv_serialization.h>
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/QScheme.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/custom_class.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/cat.h>
#endif

#include <utility>

int register_embedding_params();

#ifdef USE_FBGEMM

namespace at {
namespace native {
namespace fbgemm_utils {

namespace {

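// Checks whether a 5-d (NCDHW) tensor is physically laid out in
// channels-last-3d (NDHWC) order, i.e. the channel dimension has stride 1.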
bool IsChannelsLast3d(const Tensor& tensor) {
  if (tensor.dim() != 5) {
    return false;
  }
  const int64_t C = tensor.size(1);
  const int64_t D = tensor.size(2);
  const int64_t H = tensor.size(3);
  const int64_t W = tensor.size(4);
  return tensor.stride(0) == D * H * W * C && tensor.stride(1) == 1 &&
      tensor.stride(2) == H * W * C && tensor.stride(3) == W * C &&
      tensor.stride(4) == C;
}

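// Copies a contiguous NCDHW buffer into an NDHWC (channels-last-3d) buffer
// of the same logical shape.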
template <typename T>
void CopyToChannelsLast3dTensor(
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    const T* src,
    T* dst) {
  const int64_t inner_size = D * H * W;
  for (const auto i : c10::irange(N)) {
    for (const auto j : c10::irange(inner_size)) {
      for (const auto k : c10::irange(C)) {
        dst[(i * inner_size + j) * C + k] = src[(i * C + k) * inner_size + j];
      }
    }
  }
}

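// Reorders transposed-conv weights from an IC-first layout into a grouped
// channels-last layout; see the shape note inside the function.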
template <typename T>
void CopyICFirst3dTensorToChannelsLast3dTensor(
    int64_t G,
    int64_t IC_G,
    int64_t OC_G,
    int64_t D,
    int64_t H,
    int64_t W,
    const T* src,
    T* dst) {
  // IC OC/G THW -> G OC/G THW IC/G
  const int64_t inner_size = D * H * W;
  for (int64_t i = 0; i < G * OC_G; ++i) {
    for (const auto j : c10::irange(inner_size)) {
      for (const auto ic : c10::irange(IC_G)) {
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        int g = i / OC_G;
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        int oc = i % OC_G;
        dst[(i * inner_size + j) * IC_G + ic] =
            src[((g * IC_G + ic) * OC_G + oc) * inner_size + j];
      }
    }
  }
}

} // namespace

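// Bundles the convolution arguments into an fbgemm::conv_param_t. Note that
// `pads` holds kSpatialDim values and is replicated into both halves of the
// 2 * kSpatialDim padding array that fbgemm expects.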
template <int kSpatialDim>
fbgemm::conv_param_t<kSpatialDim> MakeFbgemmConvParam(
    int N,
    int C,
    int M,
    const std::vector<int>& image_shape,
    int groups,
    const std::vector<int>& kernels,
    const std::vector<int>& strides,
    const std::vector<int>& pads,
    const std::vector<int>& dilations,
    const std::vector<int>& output_padding,
    bool transposed) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> image_shape_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> kernels_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> strides_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim * 2> pads_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> dilations_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> output_padding_;
  std::move(
      image_shape.begin(),
      image_shape.begin() + image_shape.size(),
      image_shape_.begin());
  std::move(
      kernels.begin(), kernels.begin() + kernels.size(), kernels_.begin());
  std::move(
      strides.begin(), strides.begin() + strides.size(), strides_.begin());
  std::move(
      dilations.begin(),
      dilations.begin() + dilations.size(),
      dilations_.begin());
  std::move(
      output_padding.begin(),
      output_padding.begin() + output_padding.size(),
      output_padding_.begin());
  // Duplicate the per-dimension padding values into the front and back halves
  // of pads_.
  std::copy(pads.begin(), pads.begin() + pads.size(), pads_.begin());
  std::copy(pads.begin(), pads.begin() + pads.size(), pads_.begin() + pads.size());

  return fbgemm::conv_param_t<kSpatialDim>(
      N, // batch size
      C, // input channels
      M, // output channels
      image_shape_, // feature map size
      groups, // groups
      kernels_, // kernels
      strides_, // strides
      pads_, // paddings
      dilations_, // dilations
      output_padding_, // output paddings for conv transpose
      transposed);
}

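// Allocates an uninitialized quantized CPU tensor with explicit sizes and
// strides, backed by the given quantizer.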
Tensor MakeStridedQTensorCPU(
    const IntArrayRef& sizes,
    const IntArrayRef& strides,
    const TensorOptions& options,
    QuantizerPtr quantizer) {
  AT_ASSERT(options.device().is_cpu());
  at::native::check_size_nonnegative(sizes);
  auto* allocator = at::getCPUAllocator();
  const int64_t nelements = c10::multiply_integers(sizes);
  auto dtype = options.dtype();
  TORCH_CHECK(
      isQIntType(typeMetaToScalarType(dtype)),
      "ScalarType is not supported in new_qtensor_cpu.");
  int64_t size_bytes = nelements * dtype.itemsize();
  auto storage = c10::make_intrusive<StorageImpl>(
      StorageImpl::use_byte_size_t(),
      size_bytes,
      allocator->allocate(size_bytes),
      allocator,
      /* resizable = */ true);
  constexpr auto quantized_cpu_ks = at::DispatchKeySet(at::DispatchKey::QuantizedCPU);
  auto tensor = detail::make_tensor<QTensorImpl>(
      storage,
      quantized_cpu_ks,
      dtype,
      quantizer);
  get_qtensorimpl(tensor)->set_sizes_and_strides(sizes, strides);
  return tensor;
}

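// Creates an empty per-tensor affine quantized NCDHW tensor that uses
// channels-last-3d strides.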
Tensor MakeEmptyAffineQuantizedChannelsLast3dTensor(
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    const TensorOptions& options,
    double scale,
    int64_t zero_point) {
  return MakeStridedQTensorCPU(
      {N, C, D, H, W},
      {D * H * W * C, 1, H * W * C, W * C, C},
      options,
      make_per_tensor_affine_quantizer(
          scale, zero_point, typeMetaToScalarType(options.dtype())));
}

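// Same as above, but per-channel affine quantized along axis 0.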
Tensor MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor(
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    const TensorOptions& options,
    const Tensor& scales,
    const Tensor& zero_points) {
  return MakeStridedQTensorCPU(
      {N, C, D, H, W},
      {D * H * W * C, 1, H * W * C, W * C, C},
      options,
      make_per_channel_affine_quantizer(
          scales,
          zero_points,
          0, // axis
          typeMetaToScalarType(options.dtype())));
}

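// Returns `src` unchanged if it is already channels-last-3d; otherwise
// allocates a channels-last-3d tensor and repacks the data into it.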
Tensor ConvertToChannelsLast3dTensor(const Tensor& src) {
  TORCH_CHECK(src.dim() == 5);
  Tensor dst;
  if (IsChannelsLast3d(src)) {
    dst = src;
  } else {
    const int64_t N = src.size(0);
    const int64_t C = src.size(1);
    const int64_t D = src.size(2);
    const int64_t H = src.size(3);
    const int64_t W = src.size(4);
    dst = MakeStridedQTensorCPU(
        {N, C, D, H, W},
        {D * H * W * C, 1, H * W * C, W * C, C},
        src.options(),
        src.quantizer());
    AT_DISPATCH_QINT_TYPES(
        src.scalar_type(), "ConvertToChannelsLast3dTensor", [&]() {
          const Tensor src_contig = src.contiguous();
          CopyToChannelsLast3dTensor<scalar_t>(
              N,
              C,
              D,
              H,
              W,
              src_contig.data_ptr<scalar_t>(),
              dst.data_ptr<scalar_t>());
        });
  }
  return dst;
}

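// Unpacks 2d transposed-conv weights: regroups the packed (OC, IC/G, H, W)
// tensor back into the logical (IC, OC/G, H, W) layout.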
template <>
Tensor TransposeConvTensorUnpackConversion<2>(const Tensor& src, int groups) {
  // OC IC/G HW -> IC OC/G HW logically
  auto oc_g_ic_g_hw_tensors = src.chunk(groups);
  auto fused_tensor = at::cat(oc_g_ic_g_hw_tensors, 1);
  set_quantizer_(fused_tensor, src.quantizer());
  return fused_tensor.permute({1, 0, 2, 3});
}

template fbgemm::conv_param_t<1> MakeFbgemmConvParam<1>(
    int N,
    int C,
    int M,
    const std::vector<int>& image_shape,
    int groups,
    const std::vector<int>& kernels,
    const std::vector<int>& strides,
    const std::vector<int>& pads,
    const std::vector<int>& dilations,
    const std::vector<int>& output_padding,
    bool transposed);

template fbgemm::conv_param_t<2> MakeFbgemmConvParam<2>(
    int N,
    int C,
    int M,
    const std::vector<int>& image_shape,
    int groups,
    const std::vector<int>& kernels,
    const std::vector<int>& strides,
    const std::vector<int>& pads,
    const std::vector<int>& dilations,
    const std::vector<int>& output_padding,
    bool transposed);

template fbgemm::conv_param_t<3> MakeFbgemmConvParam<3>(
    int N,
    int C,
    int M,
    const std::vector<int>& image_shape,
    int groups,
    const std::vector<int>& kernels,
    const std::vector<int>& strides,
    const std::vector<int>& pads,
    const std::vector<int>& dilations,
    const std::vector<int>& output_padding,
    bool transposed);

template <>
Tensor TransposeConvTensorUnpackConversion<3>(const Tensor& src, int groups) {
  // OC IC/G DHW -> IC OC/G DHW logically
  auto oc_g_ic_g_hw_tensors = src.chunk(groups);
  auto fused_tensor = at::cat(oc_g_ic_g_hw_tensors, 1);
  set_quantizer_(fused_tensor, src.quantizer());
  return fused_tensor.permute({1, 0, 2, 3, 4});
}

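// 2d case: regular conv weights only need the ChannelsLast memory format;
// transposed-conv weights are regrouped (IC OC/G KH KW -> G OC/G KH KW IC/G)
// by the immediately-invoked lambda below.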
template <>
Tensor ConvertConvWeightsToChannelLastTensor<2>(
    const at::Tensor& src,
    int groups,
    bool transpose) {
  return transpose ?
      // 2D conv transpose weight transform
      // IC OC/G KH KW -> G OC/G KH KW IC/G
      [&]() {
        auto ic_g_oc_g_hw_tensors = src.chunk(groups);
        for (auto& tensor : ic_g_oc_g_hw_tensors) {
          tensor = tensor.unsqueeze(0);
        }
        auto fused_tensor = at::cat(ic_g_oc_g_hw_tensors);
        set_quantizer_(fused_tensor, src.quantizer());
        return fused_tensor.permute({0, 2, 3, 4, 1})
            .contiguous(c10::MemoryFormat::Contiguous);
      }()
      // 2d conv weight transform
      : src.contiguous(c10::MemoryFormat::ChannelsLast);
}

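// 3d case: regular conv weights go through ConvertToChannelsLast3dTensor;
// transposed-conv weights are repacked with the IC-first helper above.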
template <>
Tensor ConvertConvWeightsToChannelLastTensor<3>(
    const at::Tensor& src,
    int groups,
    bool transpose) {
  if (!transpose) {
    return ConvertToChannelsLast3dTensor(src);
  } else {
    TORCH_CHECK(src.dim() == 5);
    Tensor dst;
    const int64_t N = src.size(0);
    const int64_t IC_G = N / groups;
    const int64_t OC_G = src.size(1);
    const int64_t D = src.size(2);
    const int64_t H = src.size(3);
    const int64_t W = src.size(4);
    dst = MakeStridedQTensorCPU(
        {groups * OC_G, IC_G, D, H, W},
        {D * H * W * IC_G, 1, H * W * IC_G, W * IC_G, IC_G},
        src.options(),
        src.quantizer());
    AT_DISPATCH_QINT_TYPES(
        src.scalar_type(), "CopyICFirst3dTensorToChannelsLast3dTensor", [&]() {
          const Tensor src_contig = src.contiguous();
          CopyICFirst3dTensorToChannelsLast3dTensor<scalar_t>(
              groups,
              IC_G,
              OC_G,
              D,
              H,
              W,
              src_contig.data_ptr<scalar_t>(),
              dst.data_ptr<scalar_t>());
        });
    return dst;
  }
}

} // namespace fbgemm_utils
} // namespace native
} // namespace at

#endif // USE_FBGEMM

namespace {
// This is really terrible, but we couldn't figure out a better way to
// constexpr-convert an int to a string and then concatenate with it.
constexpr const char* _hack_int_to_class_name(int x) {
  switch (x) {
    case 2:
      return "Conv2dPackedParamsBase";
    case 3:
      return "Conv3dPackedParamsBase";
    default:
      assert(false);
      return "NotAValidDimension";
  }
}
} // namespace

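// Registers ConvPackedParamsBase<kSpatialDim> as a TorchBind custom class in
// the "quantized" namespace, with pickle hooks so packed conv weights can be
// serialized and deserialized (including older serialization formats).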
template <int kSpatialDim = 2>
TORCH_API int
register_conv_params() {
  static auto register_conv_params =
      torch::selective_class_<ConvPackedParamsBase<kSpatialDim>>(
          "quantized", TORCH_SELECTIVE_CLASS(_hack_int_to_class_name(kSpatialDim)))
          .def_pickle(
              [](const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params)
                  -> ConvParamsSerializationType { // __getstate__
                return serialize_conv<kSpatialDim>(params);
              },
              // __setstate__ takes c10::IValue because we support parsing
              // historical serialization versions.
              [](c10::IValue v)
                  -> c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> { // __setstate__
                ConvParamsSerializationTypeV3 state = parse_conv_serialized_state<kSpatialDim>(v);
                return deserialize_conv<kSpatialDim>(state);
              })
          .def("weight", [](const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& self) {
            return std::get<0>(self->unpack());
          })
          .def("bias", [](const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& self) {
            return std::get<1>(self->unpack());
          })
          .def("unpack", &ConvPackedParamsBase<kSpatialDim>::unpack)
          .def("stride", &ConvPackedParamsBase<kSpatialDim>::stride)
          .def("padding", &ConvPackedParamsBase<kSpatialDim>::padding)
          .def("output_padding", &ConvPackedParamsBase<kSpatialDim>::output_padding)
          .def("dilation", &ConvPackedParamsBase<kSpatialDim>::dilation)
          .def("groups", &ConvPackedParamsBase<kSpatialDim>::groups)
          .def("transpose", &ConvPackedParamsBase<kSpatialDim>::transpose);
  return 0;
}

template
TORCH_API int register_conv_params<2>();
template
TORCH_API int register_conv_params<3>();

TORCH_API int register_linear_params();

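// Registers LinearPackedParamsBase. __setstate__ re-prepacks the raw
// (weight, bias) pair with whichever backend the active qengine selects
// (FBGEMM/X86, QNNPACK, or oneDNN).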
TORCH_API int register_linear_params() {
  using SerializationType = std::tuple<at::Tensor, std::optional<at::Tensor>>;
  static auto register_linear_params =
      torch::selective_class_<LinearPackedParamsBase>(
          "quantized", TORCH_SELECTIVE_CLASS("LinearPackedParamsBase"))
          .def_pickle(
              [](const c10::intrusive_ptr<LinearPackedParamsBase>& params)
                  -> SerializationType { // __getstate__
                return params->unpack();
              },
              [](SerializationType state)
                  -> c10::intrusive_ptr<
                      LinearPackedParamsBase> { // __setstate__
                at::Tensor weight;
                std::optional<at::Tensor> bias;
                weight = std::move(std::get<0>(state));
                bias = std::move(std::get<1>(state));

#ifdef USE_FBGEMM
                if (at::globalContext().qEngine() == at::QEngine::FBGEMM ||
                    at::globalContext().qEngine() == at::QEngine::X86) {
                  if (weight.scalar_type() == at::kQInt8) {
                    return PackedLinearWeight::prepack(
                        std::move(weight), std::move(bias));
                  } else if (weight.scalar_type() == at::kFloat) {
                    // NB: fp16 weight is serialized as float
                    return PackedLinearWeightFp16::prepack(
                        std::move(weight), std::move(bias));
                  } else {
                    TORCH_CHECK(
                        false,
                        "Unsupported data type ",
                        c10::toString(weight.scalar_type()),
                        " in serialized LinearPackedParams object!");
                  }
                }
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
                if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
                  TORCH_CHECK(
                      weight.scalar_type() == at::kQInt8,
                      "QNNPACK only supports INT8 bit width currently. Got ",
                      c10::toString(weight.scalar_type()));
                  return PackedLinearWeightsQnnp::prepack(
                      std::move(weight), std::move(bias));
                }
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
                if (at::globalContext().qEngine() == at::QEngine::ONEDNN) {
                  TORCH_CHECK(
                      weight.scalar_type() == at::kQInt8,
                      "ONEDNN only supports INT8 bit width currently. Got ",
                      c10::toString(weight.scalar_type()));
                  return PackedLinearWeightsOnednn::prepack(
                      std::move(weight), std::move(bias));
                }
#endif // AT_MKLDNN_ENABLED()
                TORCH_CHECK(false, "Unknown qengine");
              })
          .def("bias", [](const c10::intrusive_ptr<LinearPackedParamsBase>& self) {
            return std::get<1>(self->unpack());
          })
          .def("unpack", &LinearPackedParamsBase::unpack);
  // (1) we can't (easily) return the static initializer itself because it can
  //     have a different type because of selective build
  // (2) we can't return void and still be able to call the function in the
  //     global scope
  return 0;
}

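// Registers EmbeddingPackedParamsBase; the (version, tensors, doubles, longs)
// serialization tuple is documented inside the function.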
int register_embedding_params() {
  // Type for __getstate__/__setstate__ serialization
  //
  // Element 0 is the version of the PackedParam structure
  // Element 1 is the Tensors contained in the Param instance
  // Element 2 is the double values (if any) contained in the Param instance
  // Element 3 is the int values (if any) contained in the Param instance

  using EmbeddingParamsSerializationType = std::tuple<
      int64_t, // version
      std::vector<at::Tensor>,
      std::vector<double>,
      std::vector<int64_t>>;

  static auto register_embedding_params =
      torch::selective_class_<EmbeddingPackedParamsBase>(
          "quantized", TORCH_SELECTIVE_CLASS("EmbeddingPackedParamsBase"))
          .def_pickle(
              [](const c10::intrusive_ptr<EmbeddingPackedParamsBase>& params)
                  -> EmbeddingParamsSerializationType { // __getstate__ call
                at::Tensor weight = params->unpack();
                std::vector<at::Tensor> tensors_to_serialize = {std::move(weight)};
                std::vector<double> doubles_to_serialize = {};
                int64_t bit_rate = params->bit_rate();
                int64_t version = params->version();
                std::vector<int64_t> longs_to_serialize = {bit_rate};
                return EmbeddingParamsSerializationType(
                    version,
                    std::move(tensors_to_serialize),
                    std::move(doubles_to_serialize),
                    std::move(longs_to_serialize));
              },
              [](EmbeddingParamsSerializationType state)
                  -> c10::intrusive_ptr<EmbeddingPackedParamsBase> { // __setstate__ call
                auto [version, tensors, doubles, longs] = std::move(state);

                TORCH_INTERNAL_ASSERT(tensors.size() == 1, "EmbeddingPackedParams: Expected weight tensor to be serialized");
                TORCH_INTERNAL_ASSERT(longs.size() == 1, "EmbeddingPackedParams: Expected bit_rate to be serialized");
                TORCH_CHECK(version == 1, "EmbeddingPackedParams: Currently only version 1 supported.");

                at::Tensor weight = std::move(tensors[0]);
                return PackedEmbeddingBagWeight::prepack(std::move(weight));
              })
          .def("bit_rate", &EmbeddingPackedParamsBase::bit_rate)
          .def("unpack", &EmbeddingPackedParamsBase::unpack)
          .def("version", &EmbeddingPackedParamsBase::version);

  return 0;
}

namespace {

static C10_UNUSED auto conv2d_params = register_conv_params<2>();
static C10_UNUSED auto conv3d_params = register_conv_params<3>();
static C10_UNUSED auto linear_params = register_linear_params();
static C10_UNUSED auto embedding_params = register_embedding_params();

} // namespace