1 #pragma once 2 3 #include <ATen/quantized/Quantizer.h> 4 #include <c10/core/TensorImpl.h> 5 #include <c10/util/Exception.h> 6 7 namespace at { 8 9 /** 10 * QTensorImpl is a TensorImpl for Quantized Tensors, it stores Quantizer which 11 * specifies the quantization scheme and parameters, for more information please 12 * see ATen/quantized/Quantizer.h 13 * 14 * We'll use QTensor in code or documentation to refer to a Tensor with QTensorImpl. 15 */ 16 struct TORCH_API QTensorImpl : public c10::TensorImpl { 17 public: 18 QTensorImpl( 19 Storage&& storage, 20 DispatchKeySet key_set, 21 const caffe2::TypeMeta data_type, 22 QuantizerPtr quantizer); 23 24 // See Note [Enum ImplType] 25 QTensorImpl( 26 ImplType type, 27 Storage&& storage, 28 DispatchKeySet key_set, 29 const caffe2::TypeMeta data_type, 30 QuantizerPtr quantizer); 31 32 33 // TODO: Expose in PyTorch Frontend quantizerQTensorImpl34 QuantizerPtr quantizer() { 35 return quantizer_; 36 } 37 set_quantizer_QTensorImpl38 void set_quantizer_(QuantizerPtr quantizer) { 39 quantizer_ = quantizer; 40 } 41 42 /** 43 * Return a TensorImpl that is a shallow-copy of this TensorImpl. 44 * 45 * For usage of `version_counter` and `allow_tensor_metadata_change`, 46 * see NOTE [ TensorImpl Shallow-Copying ]. 47 */ shallow_copy_and_detachQTensorImpl48 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( 49 const c10::VariableVersion& version_counter, 50 bool allow_tensor_metadata_change) const override { 51 auto impl = c10::make_intrusive<QTensorImpl>( 52 Storage(storage()), key_set(), data_type_, quantizer_); 53 copy_tensor_metadata( 54 /*src_impl=*/this, 55 /*dest_impl=*/impl.get(), 56 /*version_counter=*/version_counter, 57 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); 58 impl->refresh_numel(); 59 impl->refresh_contiguous(); 60 return impl; 61 } 62 63 /** 64 * Return a TensorImpl that is a shallow-copy of this TensorImpl. 65 * 66 * For usage of `version_counter` and `allow_tensor_metadata_change`, 67 * see NOTE [ TensorImpl Shallow-Copying ]. 68 */ shallow_copy_and_detachQTensorImpl69 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( 70 c10::VariableVersion&& version_counter, 71 bool allow_tensor_metadata_change) const override { 72 auto impl = c10::make_intrusive<QTensorImpl>( 73 Storage(storage()), key_set(), data_type_, quantizer_); 74 copy_tensor_metadata( 75 /*src_impl=*/this, 76 /*dest_impl=*/impl.get(), 77 /*version_counter=*/std::move(version_counter), 78 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); 79 impl->refresh_numel(); 80 impl->refresh_contiguous(); 81 return impl; 82 } 83 84 /** 85 * Shallow-copies data from another TensorImpl into this TensorImpl. 86 * 87 * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`, 88 * see NOTE [ TensorImpl Shallow-Copying ]. 89 */ shallow_copy_fromQTensorImpl90 void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override { 91 AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); 92 auto q_impl = static_cast<const QTensorImpl*>(impl.get()); 93 copy_tensor_metadata( 94 /*src_impl=*/q_impl, 95 /*dest_impl=*/this, 96 /*version_counter=*/version_counter(), 97 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); 98 refresh_numel(); 99 refresh_contiguous(); 100 } 101 102 private: 103 QuantizerPtr quantizer_; 104 105 const char* tensorimpl_type_name() const override; 106 107 /** 108 * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) 109 * from one TensorImpl to another TensorImpl. 110 * 111 * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ]. 112 */ copy_tensor_metadataQTensorImpl113 static void copy_tensor_metadata( 114 const QTensorImpl* src_q_impl, 115 QTensorImpl* dest_q_impl, 116 const c10::VariableVersion& version_counter, 117 bool allow_tensor_metadata_change) { 118 TensorImpl::copy_tensor_metadata(src_q_impl, dest_q_impl, version_counter, allow_tensor_metadata_change); 119 120 // OpaqueTensorImpl-specific fields. 121 dest_q_impl->quantizer_ = src_q_impl->quantizer_; 122 } 123 }; 124 125 } // namespace at 126