// Source: external/pytorch/aten/src/ATen/native/quantized/TensorFactories.cpp
// (AOSP 15 r20 xref, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <ATen/quantized/Quantizer.h>
3 #include <c10/core/QScheme.h>
4 #include <c10/core/TensorOptions.h>
5 
6 #include <utility>
7 
8 
9 namespace at::native {
10 
11 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// We explicitly pass in scale and zero_point because the infrastructure to
// support a Quantizer in the Python frontend is not ready yet. Once it is,
// we'll change these functions to take a Quantizer instead.
empty_affine_quantized(IntArrayRef size,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory,double scale,int64_t zero_point,std::optional<c10::MemoryFormat> optional_memory_format)15 Tensor empty_affine_quantized(
16     IntArrayRef size,
17     std::optional<ScalarType> dtype,
18     std::optional<Layout> layout,
19     std::optional<Device> device,
20     std::optional<bool> pin_memory,
21     double scale,
22     int64_t zero_point,
23     std::optional<c10::MemoryFormat> optional_memory_format) {
24   // See [Note: hacky wrapper removal for TensorOptions]
25   TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
26 
27   TORCH_CHECK(
28     !(options_.has_memory_format() && optional_memory_format.has_value()),
29     "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
30     "the redundant setter.");
31   auto options = options_.merge_memory_format(optional_memory_format);
32   TORCH_CHECK(
33       options.has_dtype(),
34       "Must provide data type for Tensor creation functions.");
35   return new_qtensor(
36       size,
37       options,
38       make_per_tensor_affine_quantizer(
39           scale, zero_point, typeMetaToScalarType(options.dtype())));
40 }
41 
empty_per_channel_affine_quantized(IntArrayRef size,const Tensor & scales,const Tensor & zero_points,int64_t axis,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory,std::optional<c10::MemoryFormat> optional_memory_format)42 Tensor empty_per_channel_affine_quantized(
43     IntArrayRef size,
44     const Tensor& scales,
45     const Tensor& zero_points,
46     int64_t axis,
47     std::optional<ScalarType> dtype,
48     std::optional<Layout> layout,
49     std::optional<Device> device,
50     std::optional<bool> pin_memory,
51     std::optional<c10::MemoryFormat> optional_memory_format) {
52   // See [Note: hacky wrapper removal for TensorOptions]
53   TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
54 
55   TORCH_CHECK(
56     !(options_.has_memory_format() && optional_memory_format.has_value()),
57     "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
58     "the redundant setter.");
59   auto options = options_.merge_memory_format(optional_memory_format);
60   TORCH_CHECK(
61       options.has_dtype(),
62       "Must provide data type for Tensor creation functions.");
63   QuantizerPtr quantizer = make_per_channel_affine_quantizer(
64           scales.to(options.device()), zero_points.to(options.device()), axis, typeMetaToScalarType(options.dtype()));
65   return new_qtensor(
66       size,
67       options,
68       std::move(quantizer));
69 }
70 
empty_unknown_quantized(IntArrayRef size,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory,std::optional<c10::MemoryFormat> optional_memory_format)71 Tensor empty_unknown_quantized(
72     IntArrayRef size,
73     std::optional<ScalarType> dtype,
74     std::optional<Layout> layout,
75     std::optional<Device> device,
76     std::optional<bool> pin_memory,
77     std::optional<c10::MemoryFormat> optional_memory_format) {
78   // See [Note: hacky wrapper removal for TensorOptions]
79   TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
80 
81   TORCH_CHECK(
82     !(options_.has_memory_format() && optional_memory_format.has_value()),
83     "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
84     "the redundant setter.");
85   auto options = options_.merge_memory_format(optional_memory_format);
86   TORCH_CHECK(
87       options.has_dtype(),
88       "Must provide data type for Tensor creation functions.");
89   QuantizerPtr quantizer = make_unknown_quantizer(typeMetaToScalarType(options.dtype()));
90   return new_qtensor(size, options, std::move(quantizer));
91 }
92 
empty_strided_unknown_quantized(IntArrayRef size,IntArrayRef strided,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory)93 Tensor empty_strided_unknown_quantized(
94     IntArrayRef size,
95     IntArrayRef strided,
96     std::optional<ScalarType> dtype,
97     std::optional<Layout> layout,
98     std::optional<Device> device,
99     std::optional<bool> pin_memory) {
100 
101   TORCH_CHECK(false, "empty_strided not supported on quantized tensors yet see https://github.com/pytorch/pytorch/issues/74540")
102 
103 }
104 
// Stubs that provide a clearer error message when the requested dtype is not a
// quantized dtype.
empty_affine_quantized_other_backends_stub(IntArrayRef,std::optional<ScalarType>,std::optional<Layout>,std::optional<Device>,std::optional<bool>,double,int64_t,std::optional<c10::MemoryFormat>)106 Tensor empty_affine_quantized_other_backends_stub(
107     IntArrayRef,
108     std::optional<ScalarType>,
109     std::optional<Layout>,
110     std::optional<Device>,
111     std::optional<bool>,
112     double,
113     int64_t,
114     std::optional<c10::MemoryFormat>) {
115   TORCH_CHECK(false, "Creation of quantized tensor requires quantized dtype like torch.quint8");
116 }
117 
empty_per_channel_affine_quantized_other_backends_stub(IntArrayRef,const Tensor &,const Tensor &,int64_t,std::optional<ScalarType>,std::optional<Layout>,std::optional<Device>,std::optional<bool>,std::optional<c10::MemoryFormat>)118 Tensor empty_per_channel_affine_quantized_other_backends_stub(
119     IntArrayRef,
120     const Tensor&,
121     const Tensor&,
122     int64_t,
123     std::optional<ScalarType>,
124     std::optional<Layout>,
125     std::optional<Device>,
126     std::optional<bool>,
127     std::optional<c10::MemoryFormat>) {
128   TORCH_CHECK(false, "Creation of quantized tensor requires quantized dtype like torch.quint8");
129 }
130 
// Create an empty quantized Tensor of the given size, based on the options
// and quantization parameters of the input quantized Tensor.
empty_quantized(IntArrayRef size,const Tensor & qtensor,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory,std::optional<c10::MemoryFormat> memory_format)133 Tensor empty_quantized(
134     IntArrayRef size,
135     const Tensor& qtensor,
136     std::optional<ScalarType> dtype,
137     std::optional<Layout> layout,
138     std::optional<Device> device,
139     std::optional<bool> pin_memory,
140     std::optional<c10::MemoryFormat> memory_format) {
141   TensorOptions specified_options =
142       TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
143 
144   TORCH_CHECK(
145       !(specified_options.has_memory_format() && memory_format.has_value()),
146       "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
147       "the redundant setter.");
148 
149   TensorOptions options = qtensor.options()
150                               .merge_in(specified_options)
151                               .merge_memory_format(memory_format);
152 
153   Tensor output;
154   if (qtensor.qscheme() == kPerTensorAffine) {
155     output = at::_empty_affine_quantized(
156         size, options, qtensor.q_scale(), qtensor.q_zero_point());
157   } else if (
158       qtensor.qscheme() == kPerChannelAffine ||
159       qtensor.qscheme() == kPerChannelAffineFloatQParams) {
160     output = at::_empty_per_channel_affine_quantized(
161         size,
162         qtensor.q_per_channel_scales(),
163         qtensor.q_per_channel_zero_points(),
164         qtensor.q_per_channel_axis(),
165         options);
166   } else {
167     TORCH_CHECK(
168         false,
169         "QScheme not supported by empty_quantized:",
170         toString(qtensor.qscheme()));
171   }
172   return output;
173 }
174 
175 } // namespace at::native
176