#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/Dispatch.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/IListRef.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorShape.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/cat_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/zeros_like_ops.h>
#endif

#include <algorithm>
#include <vector>

namespace at::native {

DEFINE_DISPATCH(qcat_nhwc_stub);
DEFINE_DISPATCH(qcat_relu_nhwc_stub);

namespace {

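// The NHWC fast path applies only when concatenating along the channel
// dimension (dim == 1) and every input is a 4-D tensor laid out in
// ChannelsLast (NHWC) memory format.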
bool is_cat_nhwc_fast_path(const MaterializedITensorListRef& qxs, int64_t dim) {
  TORCH_CHECK(!qxs.empty());
  bool is_fast_path = dim == 1;
  // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
  for (const at::Tensor& qx : qxs) {
    is_fast_path &= qx.dim() == 4;
    is_fast_path &= qx.is_contiguous(c10::MemoryFormat::ChannelsLast);
  }
  return is_fast_path;
}

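// The cat kernels below only support per-tensor quantization schemes.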
bool is_valid_quantization_scheme(const Tensor& t) {
  const auto qtype = t.qscheme();
  return (qtype == kPerTensorAffine) || (qtype == kPerTensorSymmetric);
}

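// Absolute tolerance used when comparing quantization scales across inputs.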
#define QPARAM_THRESHOLD 1e-04

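// Returns true if every input is quantized and shares the first input's
// qscheme, dtype, and (within QPARAM_THRESHOLD) quantization parameters.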
bool all_inputs_sharing_qparams(const MaterializedITensorListRef& qxs) {
  bool is_valid = true;
  for (const auto i : c10::irange(1, qxs.size())) {
    is_valid &= qxs[0].get().is_quantized();
    is_valid &= qxs[i].get().is_quantized() == qxs[0].get().is_quantized();
    is_valid &= qxs[i].get().qscheme() == qxs[0].get().qscheme();
    is_valid &= qxs[i].get().dtype() == qxs[0].get().dtype();
    if (qxs[0].get().qscheme() == kPerTensorAffine) {
      is_valid &= fabs(qxs[i].get().q_scale() - qxs[0].get().q_scale()) < QPARAM_THRESHOLD;
      is_valid &= qxs[i].get().q_zero_point() == qxs[0].get().q_zero_point();
    } else if (qxs[0].get().qscheme() == kPerChannelAffine) {
      is_valid &= qxs[i].get().q_per_channel_scales().isclose(
          qxs[0].get().q_per_channel_scales(), 0, QPARAM_THRESHOLD, false).all().item().to<bool>();
      is_valid &= qxs[i].get().q_per_channel_zero_points().equal(
          qxs[0].get().q_per_channel_zero_points());
    } else {
      TORCH_CHECK(false, "Unrecognized qscheme:", toString(qxs[0].get().qscheme()));
    }
  }
  return is_valid;
}

/* Quantized concatenation.
 *
 * Note: When the NHWC fast path does not apply, this function falls back to
 * dequantizing the inputs, concatenating in floating point, and requantizing
 * the result with the requested scale and zero point.
 */
template <bool ReLUFused>
Tensor quantized_cat_impl(
    const MaterializedITensorListRef& qxs,
    int64_t dim,
    double scale,
    int64_t zero_point) {
  if (is_cat_nhwc_fast_path(qxs, dim)) {
    if (ReLUFused) {
      return qcat_relu_nhwc_stub(at::kCPU, qxs, dim, scale, zero_point);
    } else {
      return qcat_nhwc_stub(at::kCPU, qxs, dim, scale, zero_point);
    }
  }

  const auto x_dtype = qxs[0].get().scalar_type();
  const auto x_qscheme = qxs[0].get().qscheme();
  std::vector<Tensor> xs;
  xs.reserve(qxs.size());
  // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
  for (const at::Tensor& qx : qxs) {
    TORCH_CHECK(x_dtype == qx.scalar_type(), "All dtypes must be the same.");
    TORCH_CHECK(
        x_qscheme == qx.qscheme(), "Quantization schemes must be the same.");
    xs.push_back(qx.dequantize());
  }
  const Tensor y = at::cat(xs, dim);
  Tensor qy;
  AT_DISPATCH_QINT_TYPES(x_dtype, "qcat", [&]() {
    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
    qy = at::quantize_per_tensor(y, scale, zero_point, SCALAR_TYPE);
    if (ReLUFused) {
      auto iter = TensorIterator::unary_op(qy, qy);
      cpu_kernel(iter, [&](scalar_t value) -> scalar_t {
        return scalar_t(std::max<underlying_t>(value.val_, zero_point));
      });
    }
  });
  return qy;
}

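// Convenience overload: materializes the ITensorListRef and forwards to the
// implementation above.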
template <bool ReLUFused>
Tensor quantized_cat_impl(
    ITensorListRef qxs,
    int64_t dim,
    double scale,
    int64_t zero_point) {
  return quantized_cat_impl<ReLUFused>(qxs.materialize(), dim, scale, zero_point);
}

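// Kernel bound to quantized::cat / quantized::cat_relu below. The output is
// quantized with the provided scale and zero point, falling back to the first
// input's quantization parameters when they are not given.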
template <bool ReLUFused = false>
Tensor qcat(
    const c10::List<Tensor>& qxs,
    int64_t dim,
    std::optional<double> scale,
    std::optional<int64_t> zero_point) {
  TORCH_CHECK(is_valid_quantization_scheme(qxs[0]),
              "Only per-tensor quantization is supported in 'cat'!")
  double _scale = scale.has_value() ? scale.value() : qxs.get(0).q_scale();
  int64_t _zero_point =
      zero_point.has_value() ? zero_point.value() : qxs.get(0).q_zero_point();
  return quantized_cat_impl<ReLUFused>(qxs, dim, _scale, _zero_point);
}

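// Kernel bound to quantized::cat_out / quantized::cat_relu_out below. The
// result is computed with the output tensor's scale and zero point and then
// copied into `out`.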
template <bool ReLUFused = false>
Tensor qcat_out(const c10::List<Tensor>& qxs, int64_t dim, Tensor out) {
  TORCH_CHECK(is_valid_quantization_scheme(qxs[0]),
              "Only per-tensor quantization is supported in 'cat'!")
  TORCH_CHECK(is_valid_quantization_scheme(out),
              "Only per-tensor quantization is supported in 'cat'!")
  auto out_ =
      quantized_cat_impl<ReLUFused>(qxs, dim, out.q_scale(), out.q_zero_point());
  at::native::copy_(out, out_, /*non_blocking=*/false);
  return out;
}

} // namespace

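// Register the kernels above for the quantized::cat family of operators on
// the QuantizedCPU dispatch key (the op schemas themselves are declared
// elsewhere in the quantized library definition).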
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::cat"), TORCH_FN(qcat<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::cat_relu"), TORCH_FN(qcat<true>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::cat_out"), TORCH_FN(qcat_out<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::cat_relu_out"), TORCH_FN(qcat_out<true>));
}

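// QuantizedCPU implementation of at::cat. The output reuses the first input's
// scale and zero point; inputs that do not share quantization parameters only
// trigger a warning (see all_inputs_sharing_qparams above).
//
// Illustrative caller-side usage (a sketch, assuming qa and qb are per-tensor
// quantized NHWC tensors with matching quantization parameters):
//   at::Tensor qy = at::cat({qa, qb}, /*dim=*/1);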
Tensor cat_quantized_cpu(const ITensorListRef& qxs, int64_t dim) {
  auto materialized = qxs.materialize();
  TORCH_CHECK(is_valid_quantization_scheme(materialized[0]),
              "Only per-tensor quantization is supported in 'cat'!");

  if (!all_inputs_sharing_qparams(materialized)) {
    // TODO: if possible change this warning to an error T194501002
    TORCH_WARN("All inputs of this cat operator must share the same quantization parameters. Otherwise large numerical inaccuracies may occur.");
  }
  check_cat_no_zero_dim(materialized);
  dim = legacy_cat_wrap_dim(dim, materialized);
  double _scale = materialized[0].get().q_scale();
  int64_t _zero_point = materialized[0].get().q_zero_point();
  return quantized_cat_impl<false>(materialized, dim, _scale, _zero_point);
}

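// Out variant of the above: concatenates into `out`, requantizing the result
// with out's scale and zero point.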
Tensor& cat_out_quantized_cpu(const ITensorListRef& qxs, int64_t dim, Tensor& out) {
  auto materialized = qxs.materialize();
  TORCH_CHECK(is_valid_quantization_scheme(materialized[0]),
              "Only per-tensor quantization is supported in 'cat'!")
  TORCH_CHECK(is_valid_quantization_scheme(out),
              "Only per-tensor quantization is supported in 'cat'!")
  check_cat_no_zero_dim(materialized);
  dim = legacy_cat_wrap_dim(dim, materialized);
  auto out_ = quantized_cat_impl<false>(qxs, dim, out.q_scale(), out.q_zero_point());
  at::native::copy_(out, out_, /*non_blocking=*/false);
  return out;
}

} // namespace at::native