xref: /aosp_15_r20/external/pytorch/aten/src/ATen/quantized/QTensorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/quantized/Quantizer.h>
4 #include <c10/core/TensorImpl.h>
5 #include <c10/util/Exception.h>
6 
7 namespace at {
8 
9 /**
10  * QTensorImpl is a TensorImpl for Quantized Tensors, it stores Quantizer which
11  * specifies the quantization scheme and parameters, for more information please
12  * see ATen/quantized/Quantizer.h
13  *
14  * We'll use QTensor in code or documentation to refer to a Tensor with QTensorImpl.
15  */
16 struct TORCH_API QTensorImpl : public c10::TensorImpl {
17  public:
18   QTensorImpl(
19       Storage&& storage,
20       DispatchKeySet key_set,
21       const caffe2::TypeMeta data_type,
22       QuantizerPtr quantizer);
23 
24   // See Note [Enum ImplType]
25   QTensorImpl(
26       ImplType type,
27       Storage&& storage,
28       DispatchKeySet key_set,
29       const caffe2::TypeMeta data_type,
30       QuantizerPtr quantizer);
31 
32 
33   // TODO: Expose in PyTorch Frontend
quantizerQTensorImpl34   QuantizerPtr quantizer() {
35     return quantizer_;
36   }
37 
set_quantizer_QTensorImpl38   void set_quantizer_(QuantizerPtr quantizer) {
39     quantizer_ = quantizer;
40   }
41 
42   /**
43    * Return a TensorImpl that is a shallow-copy of this TensorImpl.
44    *
45    * For usage of `version_counter` and `allow_tensor_metadata_change`,
46    * see NOTE [ TensorImpl Shallow-Copying ].
47    */
shallow_copy_and_detachQTensorImpl48   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
49       const c10::VariableVersion& version_counter,
50       bool allow_tensor_metadata_change) const override {
51     auto impl = c10::make_intrusive<QTensorImpl>(
52         Storage(storage()), key_set(), data_type_, quantizer_);
53     copy_tensor_metadata(
54       /*src_impl=*/this,
55       /*dest_impl=*/impl.get(),
56       /*version_counter=*/version_counter,
57       /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
58     impl->refresh_numel();
59     impl->refresh_contiguous();
60     return impl;
61   }
62 
63   /**
64    * Return a TensorImpl that is a shallow-copy of this TensorImpl.
65    *
66    * For usage of `version_counter` and `allow_tensor_metadata_change`,
67    * see NOTE [ TensorImpl Shallow-Copying ].
68    */
shallow_copy_and_detachQTensorImpl69   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
70       c10::VariableVersion&& version_counter,
71       bool allow_tensor_metadata_change) const override {
72     auto impl = c10::make_intrusive<QTensorImpl>(
73         Storage(storage()), key_set(), data_type_, quantizer_);
74     copy_tensor_metadata(
75       /*src_impl=*/this,
76       /*dest_impl=*/impl.get(),
77       /*version_counter=*/std::move(version_counter),
78       /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
79     impl->refresh_numel();
80     impl->refresh_contiguous();
81     return impl;
82   }
83 
84   /**
85    * Shallow-copies data from another TensorImpl into this TensorImpl.
86    *
87    * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`,
88    * see NOTE [ TensorImpl Shallow-Copying ].
89    */
shallow_copy_fromQTensorImpl90   void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
91     AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
92     auto q_impl = static_cast<const QTensorImpl*>(impl.get());
93     copy_tensor_metadata(
94       /*src_impl=*/q_impl,
95       /*dest_impl=*/this,
96       /*version_counter=*/version_counter(),
97       /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
98     refresh_numel();
99     refresh_contiguous();
100   }
101 
102  private:
103   QuantizerPtr quantizer_;
104 
105   const char* tensorimpl_type_name() const override;
106 
107   /**
108    * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset)
109    * from one TensorImpl to another TensorImpl.
110    *
111    * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ].
112    */
copy_tensor_metadataQTensorImpl113   static void copy_tensor_metadata(
114       const QTensorImpl* src_q_impl,
115       QTensorImpl* dest_q_impl,
116       const c10::VariableVersion& version_counter,
117       bool allow_tensor_metadata_change) {
118     TensorImpl::copy_tensor_metadata(src_q_impl, dest_q_impl, version_counter, allow_tensor_metadata_change);
119 
120     // OpaqueTensorImpl-specific fields.
121     dest_q_impl->quantizer_ = src_q_impl->quantizer_;
122   }
123 };
124 
125 } // namespace at
126