#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/FakeQuantAffine.h>

#include <c10/util/irange.h>

// FakeQuantize Op for PerChannelAffine quantization scheme.

namespace at::native {

// The kernels for these stubs are registered per backend (CPU, CUDA) via
// REGISTER_DISPATCH.
DEFINE_DISPATCH(fake_quant_per_channel_cachemask_stub);
DEFINE_DISPATCH(fake_quant_grad_learnable_channel_stub);

/* Per-channel fake-quantizes the input tensor.
Args:
  X: Forward input tensor.
  dY: Backward input tensor (_backward op only).
  scale: scale of per-channel affine quantization
  zero_point: zero point of per-channel affine quantization
  axis: int specifying the axis to be quantized
  quant_min: minimum quantized value
  quant_max: maximum quantized value
Returns:
  Fake-quantized tensor (same dtype as X).
*/

Tensor fake_quantize_per_channel_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max) {
  const auto res = at::fake_quantize_per_channel_affine_cachemask(
      self, scale, zero_point, axis, quant_min, quant_max);
  return std::get<0>(res);
}
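
// Example (illustrative sketch, not part of the original source): fake-quantize
// a 4-D activation along the channel axis (axis == 1) into the int8 range:
//
//   auto x = at::randn({2, 3, 4, 4});
//   auto scale = at::full({3}, 0.1, at::kFloat);  // one scale per channel
//   auto zero_point = at::zeros({3}, at::kInt);   // one zero point per channel
//   auto y = at::fake_quantize_per_channel_affine(
//       x, scale, zero_point, /*axis=*/1, /*quant_min=*/-128, /*quant_max=*/127);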

std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max) {
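  // The forward pass computes both the fake-quantized output Y and a boolean
  // mask recording which elements of the quantized input fell inside
  // [quant_min, quant_max]. The backward op below then reduces to a single
  // elementwise multiply, dY * mask, with no extra quantization arithmetic.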
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
              "Scale must be Float, found ", scale.scalar_type());
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Int || zero_point.scalar_type() == ScalarType::Float || zero_point.scalar_type() == ScalarType::Half,
              "Zero-point must be Int32, Float or Half, found ", zero_point.scalar_type());
  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
  TORCH_CHECK(
      scale.numel() == zero_point.numel(),
      "scale and zero-point need to have the same dimensions");
  TORCH_CHECK(
      scale.numel() == self.size(axis),
      "dimensions of scale and zero-point are not consistent with input tensor");

  TORCH_CHECK(
      quant_min <= quant_max,
      "`quant_min` should be less than or equal to `quant_max`.");

  if (!at::isFloatingType(zero_point.scalar_type())) {
    TORCH_CHECK(
        at::min(zero_point).item().toInt() >= quant_min &&
            at::max(zero_point).item().toInt() <= quant_max,
        "`zero_point` must be between `quant_min` and `quant_max`.");
  }
  TORCH_CHECK(
      axis >= 0 && axis < self.dim(),
      "`axis` must be between 0 and number of dimensions of input");

  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
  auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);

  c10::DimVector expected_shape(self.dim(), 1);
  expected_shape[axis] = self.size(axis);
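  // For example, a 4-D input with axis == 1 gives expected_shape == {1, C, 1, 1},
  // so the TensorIterators below broadcast scale and zero_point across every
  // non-channel dimension.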

  TensorIterator iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .add_output(Y)
    .add_input(self)
    .add_owned_input(native::_unsafe_view(scale, expected_shape))
    .add_owned_input(native::_unsafe_view(zero_point, expected_shape))
    .build();

  // TODO(future, optional): read once, write twice. Not done at the moment
  //   for simplicity, as we do not expect this to be a bottleneck.
  TensorIterator iter_mask = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .add_output(mask)
    .add_input(self)
    .add_owned_input(native::_unsafe_view(scale, expected_shape))
    .add_owned_input(native::_unsafe_view(zero_point, expected_shape))
    .build();

  // TODO(future, optional): look into packing the mask further (BoolTensor uses
  //   1 byte per element, we only need 1 bit per element).
  fake_quant_per_channel_cachemask_stub(iter.device_type(), iter, iter_mask, quant_min, quant_max);
  return std::make_tuple(Y, mask);
}

/* Backward path for per-channel fake-quantization, using the mask cached by
the forward pass.

Args:
  dY: output grad.
  mask: mask tensor from the forward pass.

Returns:
  dX (input grad).
*/
Tensor fake_quantize_per_channel_affine_cachemask_backward(
    const Tensor& dY,
    const Tensor& mask) {
  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool);
  TORCH_CHECK(mask.numel() == dY.numel(),
      "`mask` and `dY` are not the same size: ",
      "`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel());
  if (dY.numel() <= 0) {
    return dY;
  }
  // Note: no additional kernels needed, since mask is pre-computed
  // and we can use the existing tensor multiplication kernels.
  return dY * mask;
}

static Tensor _get_rounded_zero_point(
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  // This assumes the per channel zero point vector is single-dimensioned.
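  // For example, a float zero point of {3.7, -1.2} with quant range [0, 255]
  // rounds and clamps to {4, 0}.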
  return zero_point.round().clamp_(quant_min, quant_max);
}

Tensor _fake_quantize_learnable_per_channel_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
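  // Note: grad_factor is consumed only by the backward pass (see
  // _fake_quantize_learnable_per_channel_affine_backward below); the forward
  // computation reduces to the non-learnable per-channel op with a rounded,
  // clamped zero point.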
  Tensor zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max).to(at::kInt);
  return native::fake_quantize_per_channel_affine(
    self, scale, zero_point_rounded, axis, quant_min, quant_max);
}

std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_backward(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
  /* The gradients for scale and zero point are calculated as below:
     Let Xfq be the fake quantized version of X.
     Let Xq be the quantized version of X (clamped at qmin and qmax).
     Let Delta and z be the scale and the zero point.
     :math:
      \frac{\partial X_{fq}}{\partial \Delta} =
        \begin{cases}
          q_{\min} - z & \text{ if } X_q = q_{\min} \\
          q_{\max} - z & \text{ if } X_q = q_{\max} \\
          (X_{fq} - X) / \Delta & \text{ otherwise }
        \end{cases}

      \frac{\partial X_{fq}}{\partial z} =
        \begin{cases}
          -\Delta & \text{ if } X_q = q_{\min} \text{ or } X_q = q_{\max} \\
          0 & \text{ otherwise }
        \end{cases}
  */
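  // The stub below evaluates these expressions elementwise into dScale_vec and
  // dZeroPoint_vec (each the same shape as X); the per-channel gradients are
  // then obtained by summing over every non-channel axis.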
  auto zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max);

  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X.scalar_type() == ScalarType::Float);
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);

  TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
  TORCH_CHECK(
      quant_min <= 0 && quant_max >= 0,
      "Expecting `quant_min` <= 0 and `quant_max` >= 0");
  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
  TORCH_CHECK(
      scale.numel() == zero_point.numel(),
      "scale and zero-point need to have the same dimensions");
  TORCH_CHECK(
      scale.numel() == X.size(axis),
      "dimensions of scale and zero-point are not consistent with input tensor");

  TORCH_CHECK(
      at::min(zero_point_rounded).item().toLong() >= quant_min &&
          at::max(zero_point_rounded).item().toLong() <= quant_max,
      "`zero_point` must be between `quant_min` and `quant_max`.");

  TORCH_CHECK(
      axis >= 0 && axis < X.dim(),
      "`axis` must be between 0 and number of dimensions of input");

  if (X.numel() <= 0) {
    return std::make_tuple(X, scale, zero_point);
  }

  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto numDimensions = X.ndimension();

  // Create an axis mask for vectorizing and reshaping the scale and zero point
  // tensors into the same shape as X along the channel axis.
  c10::DimVector axis_mask(numDimensions);
  for (const auto i : c10::irange(numDimensions)) {
    axis_mask[i] = (i == axis) ? X.size(axis) : 1;
  }
  auto X_shape = X.sizes();
  auto scale_vectorized = scale.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
  auto zero_point_vectorized = zero_point_rounded.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
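  // As in the forward pass, a 4-D X with axis == 1 reshapes scale and
  // zero_point to {1, C, 1, 1} before expanding them to X's full shape.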

  auto iter = TensorIteratorConfig()
    .add_output(dX)
    .add_output(dScale_vec)
    .add_output(dZeroPoint_vec)
    .add_input(X)
    .add_input(dY)
    .add_input(scale_vectorized)
    .add_input(zero_point_vectorized)
    .build();

  fake_quant_grad_learnable_channel_stub(
    X.device().type(), iter, quant_min, quant_max, grad_factor);

  // Number of reduction axes: all of X's dimensions except the channel axis.
  auto numReductionAxes = X.ndimension() - 1;

  // Create a collection of axes that includes all but the channel axis for
  // reduction when summing over the dScale and dZeroPoint tensors.
  c10::DimVector axis_for_reduction(numReductionAxes);
  for (const auto i : c10::irange(axis)) {
    axis_for_reduction[i] = i;
  }
  for (const auto i : c10::irange(axis, numReductionAxes)) {
    axis_for_reduction[i] = i + 1;
  }
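  // For example, a 4-D X with axis == 1 reduces over axes {0, 2, 3}, leaving
  // one dScale and one dZeroPoint value per channel.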

  auto dScale = dScale_vec.sum(at::IntArrayRef(axis_for_reduction.data(), numReductionAxes));
  auto dZeroPoint = dZeroPoint_vec.sum(at::IntArrayRef(axis_for_reduction.data(), numReductionAxes));

  return std::make_tuple(dX, dScale, dZeroPoint);
}
} // namespace at::native