// aten/src/ATen/native/quantized/cpu/TensorShape.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/Dispatch.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/IListRef.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorShape.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/cat_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/zeros_like_ops.h>
#endif

#include <algorithm>
#include <vector>

namespace at::native {

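// CPU dispatch stubs for the ChannelsLast (NHWC) fast-path concatenation
// kernels, with and without fused ReLU; the kernels themselves are registered
// elsewhere via REGISTER_DISPATCH.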
DEFINE_DISPATCH(qcat_nhwc_stub);
DEFINE_DISPATCH(qcat_relu_nhwc_stub);

namespace {

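// The NHWC fast path applies only when concatenating along the channel
// dimension (dim == 1) and every input is a 4-D tensor that is contiguous in
// the ChannelsLast memory format.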
bool is_cat_nhwc_fast_path(const MaterializedITensorListRef& qxs, int64_t dim) {
  TORCH_CHECK(!qxs.empty());
  bool is_fast_path = dim == 1;
  // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
  for (const at::Tensor& qx : qxs) {
    is_fast_path &= qx.dim() == 4;
    is_fast_path &= qx.is_contiguous(c10::MemoryFormat::ChannelsLast);
  }
  return is_fast_path;
}

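// 'cat' only supports per-tensor quantized inputs (affine or symmetric).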
bool is_valid_quantization_scheme(const Tensor& t) {
  const auto qtype = t.qscheme();
  return (qtype == kPerTensorAffine) || (qtype == kPerTensorSymmetric);
}

#define QPARAM_THRESHOLD 1e-04

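// Returns true if every input shares (approximately) the same quantization
// parameters as the first one: same qscheme and dtype, and scales/zero points
// that match within QPARAM_THRESHOLD.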
bool all_inputs_sharing_qparams(const MaterializedITensorListRef& qxs) {
  bool is_valid = true;
  for (const auto i : c10::irange(1, qxs.size())) {
    is_valid &= qxs[0].get().is_quantized();
    is_valid &= qxs[i].get().is_quantized() == qxs[0].get().is_quantized();
    is_valid &= qxs[i].get().qscheme() == qxs[0].get().qscheme();
    is_valid &= qxs[i].get().dtype() == qxs[0].get().dtype();
    if (qxs[0].get().qscheme() == kPerTensorAffine) {
      is_valid &= fabs(qxs[i].get().q_scale() - qxs[0].get().q_scale()) < QPARAM_THRESHOLD;
      is_valid &= qxs[i].get().q_zero_point() == qxs[0].get().q_zero_point();
    } else if (qxs[0].get().qscheme() == kPerChannelAffine) {
      is_valid &= qxs[i].get().q_per_channel_scales().isclose(
          qxs[0].get().q_per_channel_scales(), 0, QPARAM_THRESHOLD, false).all().item().to<bool>();
      is_valid &= qxs[i].get().q_per_channel_zero_points().equal(qxs[0].get().q_per_channel_zero_points());
    } else {
      TORCH_CHECK(false, "Unrecognized qscheme:", toString(qxs[0].get().qscheme()));
    }
  }
  return is_valid;
}

/* Quantized concatenation.
 *
 * Note: Except on the NHWC fast path, this function dequantizes the inputs,
 * concatenates them in floating point, and requantizes the result with the
 * given scale and zero point.
 */
template <bool ReLUFused>
Tensor quantized_cat_impl(
    const MaterializedITensorListRef& qxs,
    int64_t dim,
    double scale,
    int64_t zero_point) {
  if (is_cat_nhwc_fast_path(qxs, dim)) {
    if (ReLUFused) {
      return qcat_relu_nhwc_stub(at::kCPU, qxs, dim, scale, zero_point);
    } else {
      return qcat_nhwc_stub(at::kCPU, qxs, dim, scale, zero_point);
    }
  }

  const auto x_dtype = qxs[0].get().scalar_type();
  const auto x_qscheme = qxs[0].get().qscheme();
  std::vector<Tensor> xs;
  xs.reserve(qxs.size());
  // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
  for (const at::Tensor& qx : qxs) {
    TORCH_CHECK(x_dtype == qx.scalar_type(), "All dtypes must be the same.");
    TORCH_CHECK(
        x_qscheme == qx.qscheme(), "Quantization schemes must be the same.");
    xs.push_back(qx.dequantize());
  }
  const Tensor y = at::cat(xs, dim);
  Tensor qy;
  AT_DISPATCH_QINT_TYPES(x_dtype, "qcat", [&]() {
    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
    qy = at::quantize_per_tensor(y, scale, zero_point, SCALAR_TYPE);
    if (ReLUFused) {
      auto iter = TensorIterator::unary_op(qy, qy);
      cpu_kernel(iter, [&](scalar_t value) -> scalar_t {
        return scalar_t(std::max<underlying_t>(value.val_, zero_point));
      });
    }
  });
  return qy;
}

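// Convenience overload that materializes the tensor list before dispatching
// to the implementation above.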
template <bool ReLUFused>
Tensor quantized_cat_impl(
    ITensorListRef qxs,
    int64_t dim,
    double scale,
    int64_t zero_point) {
  return quantized_cat_impl<ReLUFused>(qxs.materialize(), dim, scale, zero_point);
}

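// quantized::cat / quantized::cat_relu: if scale/zero_point are not provided,
// the output reuses the quantization parameters of the first input.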
template <bool ReLUFused = false>
Tensor qcat(
    const c10::List<Tensor>& qxs,
    int64_t dim,
    std::optional<double> scale,
    std::optional<int64_t> zero_point) {
  TORCH_CHECK(is_valid_quantization_scheme(qxs[0]),
              "Only per-tensor quantization is supported in 'cat'!");
  double _scale = scale.has_value() ? scale.value() : qxs.get(0).q_scale();
  int64_t _zero_point =
      zero_point.has_value() ? zero_point.value() : qxs.get(0).q_zero_point();
  return quantized_cat_impl<ReLUFused>(qxs, dim, _scale, _zero_point);
}

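// quantized::cat_out / quantized::cat_relu_out: concatenates into a
// preallocated quantized tensor, using the output tensor's scale and zero
// point.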
template <bool ReLUFused = false>
Tensor qcat_out(const c10::List<Tensor>& qxs, int64_t dim, Tensor out) {
  TORCH_CHECK(is_valid_quantization_scheme(qxs[0]),
              "Only per-tensor quantization is supported in 'cat'!");
  TORCH_CHECK(is_valid_quantization_scheme(out),
              "Only per-tensor quantization is supported in 'cat'!");
  auto out_ =
      quantized_cat_impl<ReLUFused>(qxs, dim, out.q_scale(), out.q_zero_point());
  at::native::copy_(out, out_, /*non_blocking=*/false);
  return out;
}

} // namespace

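// Register the quantized::cat variants for the QuantizedCPU dispatch key.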
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::cat"), TORCH_FN(qcat<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::cat_relu"), TORCH_FN(qcat<true>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::cat_out"), TORCH_FN(qcat_out<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::cat_relu_out"), TORCH_FN(qcat_out<true>));
}

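// at::cat for quantized CPU tensors. The output inherits the quantization
// parameters of the first input, so all inputs should share the same qparams
// (see the warning below).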
Tensor cat_quantized_cpu(const ITensorListRef& qxs, int64_t dim) {
  auto materialized = qxs.materialize();
  TORCH_CHECK(is_valid_quantization_scheme(materialized[0]),
              "Only per-tensor quantization is supported in 'cat'!");

  if (!all_inputs_sharing_qparams(materialized)) {
    // TODO: if possible change this warning to an error T194501002
    TORCH_WARN("All inputs of this cat operator must share the same quantization parameters. Otherwise large numerical inaccuracies may occur.");
  }
  check_cat_no_zero_dim(materialized);
  dim = legacy_cat_wrap_dim(dim, materialized);
  double _scale = materialized[0].get().q_scale();
  int64_t _zero_point = materialized[0].get().q_zero_point();
  return quantized_cat_impl<false>(materialized, dim, _scale, _zero_point);
}

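// Out variant of at::cat for quantized CPU tensors: the result is quantized
// with the scale and zero point of the provided output tensor and copied into
// it.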
Tensor& cat_out_quantized_cpu(const ITensorListRef& qxs, int64_t dim, Tensor& out) {
  auto materialized = qxs.materialize();
  TORCH_CHECK(is_valid_quantization_scheme(materialized[0]),
              "Only per-tensor quantization is supported in 'cat'!");
  TORCH_CHECK(is_valid_quantization_scheme(out),
              "Only per-tensor quantization is supported in 'cat'!");
  check_cat_no_zero_dim(materialized);
  dim = legacy_cat_wrap_dim(dim, materialized);
  auto out_ = quantized_cat_impl<false>(qxs, dim, out.q_scale(), out.q_zero_point());
  at::native::copy_(out, out_, /*non_blocking=*/false);
  return out;
}

}  // namespace at::native