#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/FakeQuantAffine.h>

#include <c10/util/irange.h>

// FakeQuantize Op for PerChannelAffine quantization scheme.
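//
// As an informal sketch (not a normative spec) of the math these kernels
// implement: per-channel fake quantization maps each element x of channel c to
//
//   x_fq = (clamp(round(x / scale[c] + zero_point[c]), quant_min, quant_max)
//           - zero_point[c]) * scale[c]
//
// i.e. the value is quantized and immediately dequantized, so the output stays
// floating point but carries the rounding/clamping error of the affine scheme.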

namespace at::native {

// Use REGISTER_DISPATCH to run the CPU and CUDA backends.
DEFINE_DISPATCH(fake_quant_per_channel_cachemask_stub);
DEFINE_DISPATCH(fake_quant_grad_learnable_channel_stub);

/* Per channel fake-quantizes the 'inputs' tensor.
Args:
  X: Forward input tensor.
  dY: Backward input tensor (_backward op only).
  scale: scale of per channel affine quantization
  zero_point: zero_point of per channel affine quantization
  axis: int specifying the axis to be quantized
  quant_min: minimum quantized value
  quant_max: maximum quantized value
Returns:
  Fake quantized tensor (same dtype as the input).
*/
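
// Example (illustrative sketch only; the shapes, values, and int8-like range
// below are made up):
//
//   auto X = at::randn({2, 3, 4});                // float activations
//   auto scale = at::full({3}, 0.1, at::kFloat);  // one scale per channel
//   auto zero_point = at::zeros({3}, at::kInt);   // one zero point per channel
//   // fake-quantize along dim 1 (the channel axis):
//   auto Y = at::fake_quantize_per_channel_affine(
//       X, scale, zero_point, /*axis=*/1, /*quant_min=*/-128, /*quant_max=*/127);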

Tensor fake_quantize_per_channel_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max) {
  const auto res = at::fake_quantize_per_channel_affine_cachemask(
      self, scale, zero_point, axis, quant_min, quant_max);
  return std::get<0>(res);
}

std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max) {
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
              "Scale must be Float, found ", scale.scalar_type());
  TORCH_CHECK(
      zero_point.scalar_type() == ScalarType::Int ||
          zero_point.scalar_type() == ScalarType::Float ||
          zero_point.scalar_type() == ScalarType::Half,
      "Zero-point must be Int32, Float or Half, found ", zero_point.scalar_type());
  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
  TORCH_CHECK(
      scale.numel() == zero_point.numel(),
      "scale and zero-point need to have the same dimensions");
  TORCH_CHECK(
      scale.numel() == self.size(axis),
      "dimensions of scale and zero-point are not consistent with input tensor");

  TORCH_CHECK(
      quant_min <= quant_max,
      "`quant_min` should be less than or equal to `quant_max`.");

  if (!at::isFloatingType(zero_point.scalar_type())) {
    TORCH_CHECK(
        at::min(zero_point).item().toInt() >= quant_min &&
            at::max(zero_point).item().toInt() <= quant_max,
        "`zero_point` must be between `quant_min` and `quant_max`.");
  }
  TORCH_CHECK(
      axis >= 0 && axis < self.dim(),
      "`axis` must be between 0 and number of dimensions of input");

  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
  auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);

  c10::DimVector expected_shape(self.dim(), 1);
  expected_shape[axis] = self.size(axis);

  TensorIterator iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .add_output(Y)
    .add_input(self)
    .add_owned_input(native::_unsafe_view(scale, expected_shape))
    .add_owned_input(native::_unsafe_view(zero_point, expected_shape))
    .build();

  // TODO(future, optional): read once, write twice. Not done at the moment
  // for simplicity, as we do not expect this to be a bottleneck.
  TensorIterator iter_mask = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .add_output(mask)
    .add_input(self)
    .add_owned_input(native::_unsafe_view(scale, expected_shape))
    .add_owned_input(native::_unsafe_view(zero_point, expected_shape))
    .build();

  // TODO(future, optional): look into packing the mask further (BoolTensor uses
  // 1 byte per element, we only need 1 bit per element).
  fake_quant_per_channel_cachemask_stub(
      iter.device_type(), iter, iter_mask, quant_min, quant_max);
  return std::make_tuple(Y, mask);
}

/* Backward path to fake-quantize the 'inputs' tensor per channel, with mask.

Args:
  dY: output grad.
  mask: mask tensor from the forward pass.

Returns:
  dX (input grad).
*/
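// Example (illustrative): if dY = [1.0, 2.0, 3.0] and mask = [true, false, true],
// the result is [1.0, 0.0, 3.0]; gradient flows only where the forward pass did
// not clamp.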
Tensor fake_quantize_per_channel_affine_cachemask_backward(
    const Tensor& dY,
    const Tensor& mask) {
  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
              "`mask` must be a Bool tensor, found ", mask.scalar_type());
  TORCH_CHECK(mask.numel() == dY.numel(),
      "`mask` and `dY` are not the same size: ",
      "`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel());
  if (dY.numel() == 0) {
    return dY;
  }
  // Note: no additional kernels needed, since mask is pre-computed
  // and we can use the existing tensor multiplication kernels.
  return dY * mask;
}

// Round the float zero point to the nearest integer and clamp it to the
// quantized range. This assumes the per channel zero point vector is 1-D.
static Tensor _get_rounded_zero_point(
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  return zero_point.round().clamp_(quant_min, quant_max);
}
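
// Example (illustrative): with quant_min = -128 and quant_max = 127, a float
// zero point of [-3.7, 200.2] becomes [-4., 127.] (rounded, then clamped; the
// result stays a float tensor).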

Tensor _fake_quantize_learnable_per_channel_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
  Tensor zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max).to(at::kInt);
  return native::fake_quantize_per_channel_affine(
      self, scale, zero_point_rounded, axis, quant_min, quant_max);
}
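
// Example (illustrative; values are made up). Unlike the non-learnable op above,
// `scale` and `zero_point` here are float tensors that can require grad and
// receive gradients through the backward below:
//
//   auto scale = at::full({3}, 0.1, at::kFloat).requires_grad_();
//   auto zero_point = at::zeros({3}, at::kFloat).requires_grad_();
//   auto Y = at::_fake_quantize_learnable_per_channel_affine(
//       X, scale, zero_point, /*axis=*/1, /*quant_min=*/-128, /*quant_max=*/127,
//       /*grad_factor=*/1.0);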

std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_backward(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
  /* The gradients for scale and zero point are calculated as below:
     Let Xfq be the fake quantized version of X.
     Let Xq be the quantized version of X (clamped at qmin and qmax).
     Let Delta and z be the scale and the zero point.
     :math:
      \frac{\partial X_{fq}}{\partial \Delta} =
      \begin{cases}
        q_{\min} - z & \text{ if } X_q = q_{\min} \\
        q_{\max} - z & \text{ if } X_q = q_{\max} \\
        (X_{fq} - X) / \Delta & \text{ otherwise }
      \end{cases}

      \frac{\partial X_{fq}}{\partial z} =
      \begin{cases}
        -\Delta & \text{ if } X_q = q_{\min} \text{ or } X_q = q_{\max} \\
        0 & \text{ otherwise }
      \end{cases}
  */
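
  /* Worked example (illustrative): with Delta = 0.5, z = 0, qmin = -128 and
     qmax = 127, an element X = 0.3 gives Xq = round(0.3 / 0.5) = 1 (not
     clamped), Xfq = 1 * 0.5 = 0.5, so dXfq/dDelta = (0.5 - 0.3) / 0.5 = 0.4
     and dXfq/dz = 0. */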
  auto zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max);

  TORCH_CHECK(dY.scalar_type() == ScalarType::Float, "`dY` must be Float");
  TORCH_CHECK(X.scalar_type() == ScalarType::Float, "`X` must be Float");
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float, "`scale` must be Float");
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float, "`zero_point` must be Float");

  TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
  TORCH_CHECK(
      quant_min <= 0 && quant_max >= 0,
      "Expecting `quant_min` <= 0 and `quant_max` >= 0");
  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
  TORCH_CHECK(
      scale.numel() == zero_point.numel(),
      "scale and zero-point need to have the same dimensions");
  TORCH_CHECK(
      scale.numel() == X.size(axis),
      "dimensions of scale and zero-point are not consistent with input tensor");

  TORCH_CHECK(
      at::min(zero_point_rounded).item().toLong() >= quant_min &&
          at::max(zero_point_rounded).item().toLong() <= quant_max,
      "`zero_point` must be between `quant_min` and `quant_max`.");

  TORCH_CHECK(
      axis >= 0 && axis < X.dim(),
      "`axis` must be between 0 and number of dimensions of input");

  if (X.numel() == 0) {
    return std::make_tuple(X, scale, zero_point);
  }

  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto numDimensions = X.ndimension();

  // Create an axis mask for vectorizing and reshaping the scale and zero point tensors
  // into the same shapes as X along the channel axis.
  c10::DimVector axis_mask(numDimensions);
  for (const auto i : c10::irange(numDimensions)) {
    axis_mask[i] = (i == axis) ? X.size(axis) : 1;
  }
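  // For example (illustrative), if X has shape [N, C, H, W] and axis == 1,
  // axis_mask is [1, C, 1, 1]; scale and zero_point are then reshaped to that
  // shape and broadcast-expanded against X below.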
  auto X_shape = X.sizes();
  auto scale_vectorized = scale.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
  auto zero_point_vectorized = zero_point_rounded.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);

  auto iter = TensorIteratorConfig()
    .add_output(dX)
    .add_output(dScale_vec)
    .add_output(dZeroPoint_vec)
    .add_input(X)
    .add_input(dY)
    .add_input(scale_vectorized)
    .add_input(zero_point_vectorized)
    .build();

  fake_quant_grad_learnable_channel_stub(
      X.device().type(), iter, quant_min, quant_max, grad_factor);

  // Number of reduction dimensions: every dim of X except the channel axis.
  auto numReductionDims = X.ndimension() - 1;

  // Create a collection of axes that include all but the channel axis for
  // reduction when summing over the dScale and dZeroPoint tensors.
  c10::DimVector axis_for_reduction(numReductionDims);
  for (const auto i : c10::irange(axis)) {
    axis_for_reduction[i] = i;
  }
  for (const auto i : c10::irange(axis, numReductionDims)) {
    axis_for_reduction[i] = i + 1;
  }
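  // For example (illustrative), with X.dim() == 4 and axis == 1,
  // axis_for_reduction is [0, 2, 3]; summing dScale_vec and dZeroPoint_vec over
  // those dims below yields per-channel gradients of length C.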

  auto dScale = dScale_vec.sum(at::IntArrayRef(axis_for_reduction.data(), numReductionDims));
  auto dZeroPoint = dZeroPoint_vec.sum(at::IntArrayRef(axis_for_reduction.data(), numReductionDims));

  return std::make_tuple(dX, dScale, dZeroPoint);
}
} // namespace at::native