#ifdef USE_XNNPACK

#include <ATen/ATen.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/native/quantized/cpu/XnnpackUtils.h>
#include <c10/util/irange.h>

namespace at {
namespace native {
namespace xnnp_utils {

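// Returns the tensor's sizes reordered to match its physical (memory-order)
// layout: logical NCHW/NCDHW sizes become NHWC/NDHWC extents when the
// suggested memory format is channels-last; otherwise the sizes are returned
// unchanged. Illustrative example (not taken from any particular call site):
//   at::Tensor t = at::rand({1, 3, 224, 224})
//                      .contiguous(c10::MemoryFormat::ChannelsLast);
//   auto shape = get_mem_format_aware_shape(t); // {1, 224, 224, 3}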
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in) {
  const auto mem_format = in.suggest_memory_format();
  const auto& sizes = in.sizes();
  std::vector<size_t> ret(sizes.begin(), sizes.end());
  if (mem_format == c10::MemoryFormat::ChannelsLast) {
    // NCHW -> NHWC
    // 0123 -> 0231
    ret[1] = sizes[2]; /* H */
    ret[2] = sizes[3]; /* W */
    ret[3] = sizes[1]; /* C */
  } else if (mem_format == c10::MemoryFormat::ChannelsLast3d) {
    // NCDHW -> NDHWC
    // 01234 -> 02341
    ret[1] = sizes[2]; /* D */
    ret[2] = sizes[3]; /* H */
    ret[3] = sizes[4]; /* W */
    ret[4] = sizes[1]; /* C */
  }
  return ret;
}

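// Copies qint8 weight values from `in` into `out`, reinterpreted as PT's
// underlying integer type. When that type is uint8_t (i.e. PT is
// c10::quint8), a +128 offset is added so the signed values land in the
// unsigned range; for c10::qint8 the values are copied unchanged. The caller
// is expected to have allocated `out` with at least in.numel() contiguous
// elements of type PT.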
template <typename PT>
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out) {
  using T = typename PT::underlying;
  static constexpr auto offset = std::is_same<T, uint8_t>::value ? 128 : 0;
  TORCH_CHECK(
      in.scalar_type() == c10::kQInt8,
      "q8_copy_int8_weight_and_add_offset: Expected input weight data type ",
      toString(c10::kQInt8),
      " but got ",
      toString(in.scalar_type()));
  const int8_t* in_ptr =
      reinterpret_cast<const int8_t*>(in.data_ptr<c10::qint8>());
  T* out_ptr = reinterpret_cast<T*>(out.data_ptr<PT>());

  for (const auto i : c10::irange(in.numel())) {
    out_ptr[i] = static_cast<T>(static_cast<int32_t>(in_ptr[i]) + offset);
  }
}

template void q8_copy_int8_weight_and_add_offset<c10::quint8>(
    const at::Tensor& in,
    at::Tensor& out);
template void q8_copy_int8_weight_and_add_offset<c10::qint8>(
    const at::Tensor& in,
    at::Tensor& out);
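
// Hypothetical usage with the quint8 instantiation above (tensor names and
// quantization parameters are illustrative, not taken from a call site):
//   at::Tensor w_u8 = at::_empty_affine_quantized(
//       weight.sizes(), at::device(c10::kCPU).dtype(c10::kQUInt8),
//       weight.q_scale(), weight.q_zero_point() + 128);
//   q8_copy_int8_weight_and_add_offset<c10::quint8>(weight, w_u8);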

/*
 * Stolen from fbgemm_utils::ConvertConvWeightsToChannelLastTensor to avoid
 * a dependence on USE_FBGEMM. Reorders weights into the layout XNNPACK
 * expects.
 * TODO: add a 3d variant.
 */
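// Worked example for the transpose branch (illustrative values): with
// groups = 2 and a conv-transpose weight of shape [IC, OC/G, KH, KW] =
// [4, 4, 3, 3], chunk(2) gives two [2, 4, 3, 3] pieces, unsqueeze/cat fuses
// them into [G, IC/G, OC/G, KH, KW] = [2, 2, 4, 3, 3], and
// permute({0, 2, 3, 4, 1}) yields [G, OC/G, KH, KW, IC/G] = [2, 4, 3, 3, 2].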
template <>
Tensor convert_conv_weights_to_channel_last_tensor<2>(
    const at::Tensor& src,
    int groups,
    bool transpose) {
  return transpose ?
      // 2D conv transpose weight transform
      // IC OC/G KH KW -> G OC/G KH KW IC/G
      [&]() {
        auto ic_g_oc_g_hw_tensors = src.chunk(groups);
        for (auto& tensor : ic_g_oc_g_hw_tensors) {
          tensor = tensor.unsqueeze(0);
        }
        auto fused_tensor = at::cat(ic_g_oc_g_hw_tensors);
        set_quantizer_(fused_tensor, src.quantizer());
        return fused_tensor.permute({0, 2, 3, 4, 1})
            .contiguous(c10::MemoryFormat::Contiguous);
      }()
      // 2d conv weight transform
      : src.contiguous(c10::MemoryFormat::ChannelsLast);
}
} // namespace xnnp_utils
} // namespace native
} // namespace at

#endif // USE_XNNPACK