#ifdef USE_XNNPACK

#include <ATen/ATen.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/native/quantized/cpu/XnnpackUtils.h>
#include <c10/util/irange.h>

namespace at {
namespace native {
namespace xnnp_utils {

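/*
 * Returns the tensor sizes reordered to match the tensor's suggested memory
 * format, so the shape passed on to XNNPACK describes the data as it is
 * actually laid out (e.g. NHWC for ChannelsLast rather than the logical NCHW).
 */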
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in) {
  const auto mem_format = in.suggest_memory_format();
  const auto& sizes = in.sizes();
  std::vector<size_t> ret(sizes.begin(), sizes.end());
  if (mem_format == c10::MemoryFormat::ChannelsLast) {
    // NCHW -> NHWC
    // 0123 -> 0231
    ret[1] = sizes[2]; /* H */
    ret[2] = sizes[3]; /* W */
    ret[3] = sizes[1]; /* C */
  } else if (mem_format == c10::MemoryFormat::ChannelsLast3d) {
    // NCDHW -> NDHWC
    // 01234 -> 02341
    ret[1] = sizes[2]; /* D */
    ret[2] = sizes[3]; /* H */
    ret[3] = sizes[4]; /* W */
    ret[4] = sizes[1]; /* C */
  }
  return ret;
}

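/*
 * Copies qint8 weight values from `in` into `out`, whose element type is the
 * underlying type of PT. When PT is c10::quint8 each value is shifted by +128
 * so the signed weights land in the unsigned range expected by uint8 kernels;
 * for c10::qint8 the copy is offset-free.
 */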
template <typename PT>
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out) {
  using T = typename PT::underlying;
  static constexpr auto offset = std::is_same<T, uint8_t>::value ? 128 : 0;
  TORCH_CHECK(
      in.scalar_type() == c10::kQInt8,
      "q8_copy_int8_weight_and_add_offset: Expected input weight data type ",
      toString(c10::kQInt8),
      " but got ",
      toString(in.scalar_type()));
  const int8_t* in_ptr =
      reinterpret_cast<const int8_t*>(in.data_ptr<c10::qint8>());
  T* out_ptr = reinterpret_cast<T*>(out.data_ptr<PT>());

  for (const auto i : c10::irange(in.numel())) {
    out_ptr[i] = static_cast<T>(static_cast<int32_t>(in_ptr[i]) + offset);
  }
}

template void q8_copy_int8_weight_and_add_offset<c10::quint8>(
    const at::Tensor& in,
    at::Tensor& out);
template void q8_copy_int8_weight_and_add_offset<c10::qint8>(
    const at::Tensor& in,
    at::Tensor& out);

/*
 * Stolen from fbgemm_utils::ConvertConvWeightsToChannelLastTensor to avoid
 * dependence on USE_FBGEMM. Reorder weights to the format xnnpack expects.
 * TODO: add a 3d variant.
 */
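/*
 * Transposed-conv shape flow (assuming the incoming weight is laid out as
 * IC x OC/G x KH x KW, as noted in the inline comment below):
 *   chunk(groups)        -> G tensors of shape IC/G x OC/G x KH x KW
 *   unsqueeze + cat      -> G x IC/G x OC/G x KH x KW
 *   permute(0, 2, 3, 4, 1) -> G x OC/G x KH x KW x IC/G
 */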
template <>
Tensor convert_conv_weights_to_channel_last_tensor<2>(
    const at::Tensor& src,
    int groups,
    bool transpose) {
  return transpose ?
                   // 2D conv transpose weight transform
                   // IC OC/G KH KW -> G OC/G KH KW IC/G
      [&]() {
        auto ic_g_oc_g_hw_tensors = src.chunk(groups);
        for (auto& tensor : ic_g_oc_g_hw_tensors) {
          tensor = tensor.unsqueeze(0);
        }
        auto fused_tensor = at::cat(ic_g_oc_g_hw_tensors);
        set_quantizer_(fused_tensor, src.quantizer());
        return fused_tensor.permute({0, 2, 3, 4, 1})
            .contiguous(c10::MemoryFormat::Contiguous);
      }()
                   // 2d conv weight transform
                   : src.contiguous(c10::MemoryFormat::ChannelsLast);
}
} // namespace xnnp_utils
} // namespace native
} // namespace at

#endif // USE_XNNPACK