// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/Copy.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/ATen.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/quantized/Copy.h>
#include <c10/core/MemoryFormat.h>
#include <c10/util/irange.h>

7 
8 namespace at::native {
9 
10 // Copying from float to QInt, used for assigning float value to QTensor
11 // The second exception condition `self.is_contiguous() && src.is_contiguous()`
12 // forces both the self & src tensors to be contiguous.
13 // This means that assignment of a non-contiguous quantized subtensor is currently not supported in pytorch
14 // e.g., Consider a 2x2 quantized tensor qt1 and a non-quantized tensor t2. The operation
15 // `qt1[:, 0] = t2[:, 0]` would trigger the exception b/c neither the LHS nor RHS is contiguous
quantized_copy_from_float_(Tensor & self,const Tensor & src)16 Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src) {
17   TORCH_CHECK(
18       src.scalar_type() == at::kFloat,
19       "Quantized copy only works with kFloat as source Tensor");
20   TORCH_CHECK(
21       (self.is_contiguous() && src.is_contiguous()) ||
22       (self.is_contiguous(at::MemoryFormat::ChannelsLast) && src.is_contiguous(at::MemoryFormat::ChannelsLast)),
23       "Quantized copy only works with contiguous and NHWC Tensors");
24   TORCH_CHECK(
25       self.sizes().equals(src.sizes()),
26       "Quantized copy only works with Tensors with the same shape");
27   AT_DISPATCH_QINT_TYPES(self.scalar_type(), "Copy", [&]() {
28     if (self.qscheme() == kPerChannelAffine || self.qscheme() == kPerChannelAffineFloatQParams
29         || self.qscheme() == kPerChannelSymmetric) {
30       quantize_tensor_per_channel_affine(src, self, self.q_per_channel_scales(),
31                                          self.q_per_channel_zero_points(),
32                                          self.q_per_channel_axis());
33     } else {
34       quantize_tensor_per_tensor_affine(src, self, self.q_scale(), self.q_zero_point());
35     }
36   });
37   return self;
38 }
39 } // namespace at::native
40