#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Utils.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/conv_serialization.h>
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/QScheme.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/custom_class.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/cat.h>

#include <utility>
#endif

int register_embedding_params();

#ifdef USE_FBGEMM

namespace at {
namespace native {
namespace fbgemm_utils {

namespace {

bool IsChannelsLast3d(const Tensor& tensor) {
  if (tensor.dim() != 5) {
    return false;
  }
  const int64_t C = tensor.size(1);
  const int64_t D = tensor.size(2);
  const int64_t H = tensor.size(3);
  const int64_t W = tensor.size(4);
  return tensor.stride(0) == D * H * W * C && tensor.stride(1) == 1 &&
      tensor.stride(2) == H * W * C && tensor.stride(3) == W * C &&
      tensor.stride(4) == C;
}
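
// Worked example of the channels-last-3d (NDHWC) layout checked above, with
// made-up sizes for illustration: a tensor with sizes {N, C, D, H, W} =
// {1, 3, 2, 4, 5} is channels-last-3d when its strides are
// {D*H*W*C, 1, H*W*C, W*C, C} = {120, 1, 60, 15, 3}, i.e. the channel index
// varies fastest in memory.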

template <typename T>
void CopyToChannelsLast3dTensor(
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    const T* src,
    T* dst) {
  const int64_t inner_size = D * H * W;
  for (const auto i : c10::irange(N)) {
    for (const auto j : c10::irange(inner_size)) {
      for (const auto k : c10::irange(C)) {
        dst[(i * inner_size + j) * C + k] = src[(i * C + k) * inner_size + j];
      }
    }
  }
}
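
// Index bookkeeping for the copy above, with small made-up sizes for
// illustration: take N = 1, C = 2, D = 1, H = W = 2, so inner_size = 4.
// The element at batch i = 0, channel k = 1, spatial offset j = 2 is read
// from the contiguous NCDHW source at (i * C + k) * inner_size + j = 6 and
// written to the NDHWC destination at (i * inner_size + j) * C + k = 5.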

template <typename T>
void CopyICFirst3dTensorToChannelsLast3dTensor(
    int64_t G,
    int64_t IC_G,
    int64_t OC_G,
    int64_t D,
    int64_t H,
    int64_t W,
    const T* src,
    T* dst) {
  // IC OC/G THW -> G OC/G THW IC/G
  const int64_t inner_size = D * H * W;
  for (int64_t i = 0; i < G * OC_G; ++i) {
    for (const auto j : c10::irange(inner_size)) {
      for (const auto ic : c10::irange(IC_G)) {
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        int g = i / OC_G;
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        int oc = i % OC_G;
        dst[(i * inner_size + j) * IC_G + ic] =
            src[((g * IC_G + ic) * OC_G + oc) * inner_size + j];
      }
    }
  }
}
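
// In the transform above, i enumerates the fused (group, output-channel)
// index of the destination. As an illustration with G = 2 and OC_G = 3,
// i = 4 decomposes into g = i / OC_G = 1 and oc = i % OC_G = 1, i.e. the
// second output channel of the second group; its IC_G input-channel values
// are gathered from the group-g slice of the IC-first source.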

} // namespace

template <int kSpatialDim>
fbgemm::conv_param_t<kSpatialDim> MakeFbgemmConvParam(
    int N,
    int C,
    int M,
    const std::vector<int>& image_shape,
    int groups,
    const std::vector<int>& kernels,
    const std::vector<int>& strides,
    const std::vector<int>& pads,
    const std::vector<int>& dilations,
    const std::vector<int>& output_padding,
    bool transposed) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> image_shape_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> kernels_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> strides_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim * 2> pads_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> dilations_;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int, kSpatialDim> output_padding_;
  std::move(image_shape.begin(), image_shape.begin() + image_shape.size(), image_shape_.begin());
  std::move(
      kernels.begin(), kernels.begin() + kernels.size(), kernels_.begin());
  std::move(
      strides.begin(), strides.begin() + strides.size(), strides_.begin());
  std::move(
      dilations.begin(),
      dilations.begin() + dilations.size(),
      dilations_.begin());
  std::move(
      output_padding.begin(),
      output_padding.begin() + output_padding.size(),
      output_padding_.begin());
  // pads_ holds kSpatialDim * 2 entries (padding for both sides of each
  // spatial dimension), so the kSpatialDim caller-provided values are written
  // twice: once into the first half and once into the second half.
  std::copy(pads.begin(), pads.begin() + pads.size(), pads_.begin());
  std::move(pads.begin(), pads.begin() + pads.size(), pads_.begin() + pads.size());

  return fbgemm::conv_param_t<kSpatialDim>(
      N, // batch size
      C, // input channels
      M, // output channels
      image_shape_, // feature map size
      groups, // groups
      kernels_, // kernels
      strides_, // strides
      pads_, // paddings
      dilations_, // dilations
      output_padding_, // output paddings for conv transpose
      transposed);
}
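
// Illustrative call, assuming a hypothetical 3x3, stride-2 2-d convolution
// with symmetric padding 1: MakeFbgemmConvParam<2>(N, C, M, {H, W}, groups,
// /*kernels=*/{3, 3}, /*strides=*/{2, 2}, /*pads=*/{1, 1},
// /*dilations=*/{1, 1}, /*output_padding=*/{0, 0}, /*transposed=*/false).
// The two caller-side pad values are expanded to the four-entry pads_ array
// {1, 1, 1, 1} before being handed to fbgemm::conv_param_t<2>.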

Tensor MakeStridedQTensorCPU(
    const IntArrayRef& sizes,
    const IntArrayRef& strides,
    const TensorOptions& options,
    QuantizerPtr quantizer) {
  AT_ASSERT(options.device().is_cpu());
  at::native::check_size_nonnegative(sizes);
  auto* allocator = at::getCPUAllocator();
  const int64_t nelements = c10::multiply_integers(sizes);
  auto dtype = options.dtype();
  TORCH_CHECK(
      isQIntType(typeMetaToScalarType(dtype)),
      "ScalarType is not supported in new_qtensor_cpu.");
  int64_t size_bytes = nelements * dtype.itemsize();
  auto storage = c10::make_intrusive<StorageImpl>(
      StorageImpl::use_byte_size_t(),
      size_bytes,
      allocator->allocate(size_bytes),
      allocator,
      /* resizable = */ true);
  constexpr auto quantized_cpu_ks = at::DispatchKeySet(at::DispatchKey::QuantizedCPU);
  auto tensor = detail::make_tensor<QTensorImpl>(
      storage,
      quantized_cpu_ks,
      dtype,
      quantizer);
  get_qtensorimpl(tensor)->set_sizes_and_strides(sizes, strides);
  return tensor;
}

Tensor MakeEmptyAffineQuantizedChannelsLast3dTensor(
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    const TensorOptions& options,
    double scale,
    int64_t zero_point) {
  return MakeStridedQTensorCPU(
      {N, C, D, H, W},
      {D * H * W * C, 1, H * W * C, W * C, C},
      options,
      make_per_tensor_affine_quantizer(
          scale, zero_point, typeMetaToScalarType(options.dtype())));
}

Tensor MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor(
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    const TensorOptions& options,
    const Tensor& scales,
    const Tensor& zero_points) {
  return MakeStridedQTensorCPU(
      {N, C, D, H, W},
      {D * H * W * C, 1, H * W * C, W * C, C},
      options,
      make_per_channel_affine_quantizer(
          scales,
          zero_points,
          0, // axis
          typeMetaToScalarType(options.dtype())));
}

Tensor ConvertToChannelsLast3dTensor(const Tensor& src) {
  TORCH_CHECK(src.dim() == 5);
  Tensor dst;
  if (IsChannelsLast3d(src)) {
    dst = src;
  } else {
    const int64_t N = src.size(0);
    const int64_t C = src.size(1);
    const int64_t D = src.size(2);
    const int64_t H = src.size(3);
    const int64_t W = src.size(4);
    dst = MakeStridedQTensorCPU(
        {N, C, D, H, W},
        {D * H * W * C, 1, H * W * C, W * C, C},
        src.options(),
        src.quantizer());
    AT_DISPATCH_QINT_TYPES(
        src.scalar_type(), "ConvertToChannelsLast3dTensor", [&]() {
          const Tensor src_contig = src.contiguous();
          CopyToChannelsLast3dTensor<scalar_t>(
              N,
              C,
              D,
              H,
              W,
              src_contig.data_ptr<scalar_t>(),
              dst.data_ptr<scalar_t>());
        });
  }
  return dst;
}

template <>
Tensor TransposeConvTensorUnpackConversion<2>(const Tensor& src, int groups) {
  // OC IC/G HW -> IC OC/G HW logically
  auto oc_g_ic_g_hw_tensors = src.chunk(groups);
  auto fused_tensor = at::cat(oc_g_ic_g_hw_tensors, 1);
  set_quantizer_(fused_tensor, src.quantizer());
  return fused_tensor.permute({1, 0, 2, 3});
}
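
// Shape bookkeeping for the 2-d specialization above: src has logical shape
// (OC, IC/G, KH, KW); chunk(groups) yields G tensors of shape
// (OC/G, IC/G, KH, KW); concatenating them along dim 1 gives
// (OC/G, IC, KH, KW); the final permute returns a view of shape
// (IC, OC/G, KH, KW), matching the "OC IC/G HW -> IC OC/G HW" comment.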

template fbgemm::conv_param_t<1> MakeFbgemmConvParam<1>(
    int N,
    int C,
    int M,
    const std::vector<int>& image_shape,
    int groups,
    const std::vector<int>& kernels,
    const std::vector<int>& strides,
    const std::vector<int>& pads,
    const std::vector<int>& dilations,
    const std::vector<int>& output_padding,
    bool transposed);

template fbgemm::conv_param_t<2> MakeFbgemmConvParam<2>(
    int N,
    int C,
    int M,
    const std::vector<int>& image_shape,
    int groups,
    const std::vector<int>& kernels,
    const std::vector<int>& strides,
    const std::vector<int>& pads,
    const std::vector<int>& dilations,
    const std::vector<int>& output_padding,
    bool transposed);

template fbgemm::conv_param_t<3> MakeFbgemmConvParam<3>(
    int N,
    int C,
    int M,
    const std::vector<int>& image_shape,
    int groups,
    const std::vector<int>& kernels,
    const std::vector<int>& strides,
    const std::vector<int>& pads,
    const std::vector<int>& dilations,
    const std::vector<int>& output_padding,
    bool transposed);
template <>
Tensor TransposeConvTensorUnpackConversion<3>(const Tensor& src, int groups) {
  // OC IC/G DHW -> IC OC/G DHW logically
  auto oc_g_ic_g_hw_tensors = src.chunk(groups);
  auto fused_tensor = at::cat(oc_g_ic_g_hw_tensors, 1);
  set_quantizer_(fused_tensor, src.quantizer());
  return fused_tensor.permute({1, 0, 2, 3, 4});
}

template <>
Tensor ConvertConvWeightsToChannelLastTensor<2>(
    const at::Tensor& src,
    int groups,
    bool transpose) {
  return transpose ?
                   // 2D conv transpose weight transform
                   // IC OC/G KH KW -> G OC/G KH KW IC/G
      [&]() {
        auto ic_g_oc_g_hw_tensors = src.chunk(groups);
        for (auto& tensor : ic_g_oc_g_hw_tensors) {
          tensor = tensor.unsqueeze(0);
        }
        auto fused_tensor = at::cat(ic_g_oc_g_hw_tensors);
        set_quantizer_(fused_tensor, src.quantizer());
        return fused_tensor.permute({0, 2, 3, 4, 1})
            .contiguous(c10::MemoryFormat::Contiguous);
      }()
                   // 2d conv weight transform
                   : src.contiguous(c10::MemoryFormat::ChannelsLast);
}
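
// Shape bookkeeping for the transpose branch above: src has logical shape
// (IC, OC/G, KH, KW); chunk(groups) yields G tensors of shape
// (IC/G, OC/G, KH, KW); each is unsqueezed to (1, IC/G, OC/G, KH, KW) and the
// concatenation gives (G, IC/G, OC/G, KH, KW); the permute then produces
// (G, OC/G, KH, KW, IC/G), matching the "IC OC/G KH KW -> G OC/G KH KW IC/G"
// comment, and the result is finally made contiguous in that order.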

template <>
Tensor ConvertConvWeightsToChannelLastTensor<3>(
    const at::Tensor& src,
    int groups,
    bool transpose) {
  if (!transpose) {
    return ConvertToChannelsLast3dTensor(src);
  } else {
    TORCH_CHECK(src.dim() == 5);
    Tensor dst;
    const int64_t N = src.size(0);
    const int64_t IC_G = N / groups;
    const int64_t OC_G = src.size(1);
    const int64_t D = src.size(2);
    const int64_t H = src.size(3);
    const int64_t W = src.size(4);
    dst = MakeStridedQTensorCPU(
        {groups * OC_G, IC_G, D, H, W},
        {D * H * W * IC_G, 1, H * W * IC_G, W * IC_G, IC_G},
        src.options(),
        src.quantizer());
    AT_DISPATCH_QINT_TYPES(
        src.scalar_type(), "CopyICFirst3dTensorToChannelsLast3dTensor", [&]() {
          const Tensor src_contig = src.contiguous();
          CopyICFirst3dTensorToChannelsLast3dTensor<scalar_t>(
              groups,
              IC_G,
              OC_G,
              D,
              H,
              W,
              src_contig.data_ptr<scalar_t>(),
              dst.data_ptr<scalar_t>());
        });
    return dst;
  }
}

} // namespace fbgemm_utils
} // namespace native
} // namespace at


#endif // USE_FBGEMM

namespace {
  // This is really terrible, but we couldn't figure out a better way to
  // constexpr-convert an int to a string and then concatenate it with other
  // strings.
  constexpr const char* _hack_int_to_class_name(int x) {
    switch(x) {
      case 2:
        return "Conv2dPackedParamsBase";
      case 3:
        return "Conv3dPackedParamsBase";
      default:
        assert(false);
        return "NotAValidDimension";
    }
  }
}
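
// With kSpatialDim = 2 the helper above yields "Conv2dPackedParamsBase", so
// the class below is registered under the "quantized" namespace as
// quantized::Conv2dPackedParamsBase (and as Conv3dPackedParamsBase for
// kSpatialDim = 3).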

template <int kSpatialDim = 2>
TORCH_API int
register_conv_params() {
  static auto register_conv_params =
    torch::selective_class_<ConvPackedParamsBase<kSpatialDim>>(
        "quantized", TORCH_SELECTIVE_CLASS(_hack_int_to_class_name(kSpatialDim)))
    .def_pickle(
        [](const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params)
        -> ConvParamsSerializationType { // __getstate__
          return serialize_conv<kSpatialDim>(params);
        },
        // __setstate__ takes c10::IValue because we support parsing historical
        // serialization versions.
        [](c10::IValue v)
        -> c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> { // __setstate__
          ConvParamsSerializationTypeV3 state = parse_conv_serialized_state<kSpatialDim>(v);
          return deserialize_conv<kSpatialDim>(state);
        })
    .def("weight", [](const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& self) {
                     return std::get<0>(self->unpack());
                   })
    .def("bias", [](const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& self) {
                     return std::get<1>(self->unpack());
                 })
    .def("unpack", &ConvPackedParamsBase<kSpatialDim>::unpack)
    .def("stride", &ConvPackedParamsBase<kSpatialDim>::stride)
    .def("padding", &ConvPackedParamsBase<kSpatialDim>::padding)
    .def("output_padding", &ConvPackedParamsBase<kSpatialDim>::output_padding)
    .def("dilation", &ConvPackedParamsBase<kSpatialDim>::dilation)
    .def("groups", &ConvPackedParamsBase<kSpatialDim>::groups)
    .def("transpose", &ConvPackedParamsBase<kSpatialDim>::transpose);
  return 0;
}

template
TORCH_API int register_conv_params<2>();
template
TORCH_API int register_conv_params<3>();

TORCH_API int register_linear_params();

TORCH_API int register_linear_params() {
  using SerializationType = std::tuple<at::Tensor, std::optional<at::Tensor>>;
  static auto register_linear_params =
      torch::selective_class_<LinearPackedParamsBase>(
          "quantized", TORCH_SELECTIVE_CLASS("LinearPackedParamsBase"))
          .def_pickle(
              [](const c10::intrusive_ptr<LinearPackedParamsBase>& params)
                  -> SerializationType { // __getstate__
                return params->unpack();
              },
              [](SerializationType state)
                  -> c10::intrusive_ptr<
                      LinearPackedParamsBase> { // __setstate__
                at::Tensor weight;
                std::optional<at::Tensor> bias;
                weight = std::move(std::get<0>(state));
                bias = std::move(std::get<1>(state));

#ifdef USE_FBGEMM
                if (at::globalContext().qEngine() == at::QEngine::FBGEMM ||
                    at::globalContext().qEngine() == at::QEngine::X86) {
                  if (weight.scalar_type() == at::kQInt8) {
                    return PackedLinearWeight::prepack(
                        std::move(weight), std::move(bias));
                  } else if (weight.scalar_type() == at::kFloat) {
                    // NB: fp16 weight is serialized as float
                    return PackedLinearWeightFp16::prepack(
                        std::move(weight), std::move(bias));
                  } else {
                    TORCH_CHECK(
                        false,
                        "Unsupported data type ",
                        c10::toString(weight.scalar_type()),
                        " in serialized LinearPackedParams object!");
                  }
                }
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
                if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
                  TORCH_CHECK(
                      weight.scalar_type() == at::kQInt8,
                      "QNNPACK only supports INT8 bit width currently. Got ",
                      c10::toString(weight.scalar_type()));
                  return PackedLinearWeightsQnnp::prepack(
                      std::move(weight), std::move(bias));
                }
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
                if (at::globalContext().qEngine() == at::QEngine::ONEDNN) {
                  TORCH_CHECK(
                      weight.scalar_type() == at::kQInt8,
                      "ONEDNN only supports INT8 bit width currently. Got ",
                      c10::toString(weight.scalar_type()));
                  return PackedLinearWeightsOnednn::prepack(
                      std::move(weight), std::move(bias));
                }
#endif // #if AT_MKLDNN_ENABLED()
                TORCH_CHECK(false, "Unknown qengine");
              })
              .def("bias", [](const c10::intrusive_ptr<LinearPackedParamsBase>& self) {
                  return std::get<1>(self->unpack());
                 })
              .def("unpack", &LinearPackedParamsBase::unpack);
  // (1) we can't (easily) return the static initializer itself because its
  //     type can differ under selective build
  // (2) we can't return void and still be able to call this function at
  //     global scope (its result initializes the static variables at the
  //     bottom of this file)
  return 0;
}


int register_embedding_params() {
  // Type for __getstate__/__setstate__ serialization
  //
  // Element 0 is the version of the PackedParam structure
  // Element 1 is the Tensors contained in the Param instance
  // Element 2 is the double values (if any) contained in the Param instance
  // Element 3 is the int values (if any) contained in the Param instance

  using EmbeddingParamsSerializationType = std::tuple<
    int64_t, // version
    std::vector<at::Tensor>,
    std::vector<double>,
    std::vector<int64_t>>;
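
  // As an illustration, a 4-bit row-wise packed embedding weight would
  // round-trip through the __getstate__/__setstate__ pair below as the tuple
  // {/*version=*/1, {packed_weight}, /*doubles=*/{}, /*longs=*/{4}},
  // where 4 is the bit_rate.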

  static auto register_embedding_params =
    torch::selective_class_<EmbeddingPackedParamsBase>(
      "quantized", TORCH_SELECTIVE_CLASS("EmbeddingPackedParamsBase"))
      .def_pickle(
          [](const c10::intrusive_ptr<EmbeddingPackedParamsBase>& params)
              -> EmbeddingParamsSerializationType { // __getstate__ call
            at::Tensor weight = params->unpack();
            std::vector<at::Tensor> tensors_to_serialize = {std::move(weight)};
            std::vector<double> doubles_to_serialize = {};
            int64_t bit_rate = params->bit_rate();
            int64_t version = params->version();
            std::vector<int64_t> longs_to_serialize = {bit_rate};
            return EmbeddingParamsSerializationType(
              version,
              std::move(tensors_to_serialize),
              std::move(doubles_to_serialize),
              std::move(longs_to_serialize));
          },
          [](EmbeddingParamsSerializationType state)
              -> c10::intrusive_ptr<EmbeddingPackedParamsBase> { // __setstate__ call

            auto [version, tensors, doubles, longs] = std::move(state);

            TORCH_INTERNAL_ASSERT(tensors.size() == 1, "EmbeddingPackedParams: Expected weight tensor to be serialized");
            TORCH_INTERNAL_ASSERT(longs.size() == 1, "EmbeddingPackedParams: Expected bit_rate to be serialized");
            TORCH_CHECK(version == 1, "EmbeddingPackedParams: Currently only version 1 supported.");

            at::Tensor weight = std::move(tensors[0]);
            return PackedEmbeddingBagWeight::prepack(std::move(weight));
          })
      .def("bit_rate", &EmbeddingPackedParamsBase::bit_rate)
      .def("unpack", &EmbeddingPackedParamsBase::unpack)
      .def("version", &EmbeddingPackedParamsBase::version);

  return 0;
}

namespace {

static C10_UNUSED auto conv2d_params = register_conv_params<2>();
static C10_UNUSED auto conv3d_params = register_conv_params<3>();
static C10_UNUSED auto linear_params = register_linear_params();
static C10_UNUSED auto embedding_params = register_embedding_params();

} // namespace
563