#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/quantized/AffineQuantizerBase.h>

namespace at {
namespace native {

// Affine quantization entry points. Each function fills the preallocated
// output tensor (`qtensor` for quantize, `rtensor` for dequantize) and
// returns a reference to it; implementations are selected via the dispatch
// stubs declared further down in this header.

// Per-tensor affine quantization: one (scale, zero_point) pair applied to
// the whole of `rtensor`, written into `qtensor`.
Tensor& quantize_tensor_per_tensor_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point);

// Per-channel affine quantization: `scales`/`zero_points` hold one entry
// per slice along dimension `axis`.
// NOTE(review): `zero_points` is taken by value here, while the
// float_qparams variants below take it by const& — consider unifying, but
// confirm against the out-of-line definitions first.
Tensor& quantize_tensor_per_channel_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    Tensor zero_points,
    int64_t axis);

// Per-channel variant whose quantization parameters are floating point
// (see AffineQuantizerBase.h for the underlying kernels).
Tensor& quantize_tensor_per_channel_float_qparams(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

// Inverse of quantize_tensor_per_tensor_affine: reconstructs real values
// from `qtensor` into `rtensor`.
Tensor& dequantize_tensor_per_tensor_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point);

// Inverse of quantize_tensor_per_channel_affine.
// NOTE(review): same by-value `zero_points` asymmetry as the quantize
// counterpart above.
Tensor& dequantize_tensor_per_channel_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    Tensor zero_points,
    int64_t axis);

// Inverse of quantize_tensor_per_channel_float_qparams.
Tensor& dequantize_tensor_per_channel_float_qparams(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

// Function-pointer signatures for the backend kernels registered with the
// DispatchStub mechanism; each mirrors the corresponding declaration above.

using quantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point);

using quantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using quantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using dequantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point);

using dequantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

// Sub-byte per-tensor variants (presumably for < 8-bit quantized dtypes —
// confirm against the kernel implementations). Note these take float
// scale/zero_point rather than double/int64_t.
using quantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point);

using dequantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point);

// Dispatch stubs: backends (CPU/CUDA/...) register kernels matching the
// *_fn signatures above; callers invoke the stub with a device type.

DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_fn,
    quantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_affine_fn,
    quantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_float_qparams_fn,
    quantize_tensor_per_channel_float_qparams_stub);

DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_fn,
    dequantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_affine_fn,
    dequantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_float_qparams_fn,
    dequantize_tensor_per_channel_float_qparams_stub);

DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_sub_byte_fn,
    quantize_tensor_per_tensor_affine_sub_byte_stub);

DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_sub_byte_fn,
    dequantize_tensor_per_tensor_affine_sub_byte_stub);

// Typed per-tensor helpers, parameterized on the quantized element type T.
// Declarations only — definitions (and explicit instantiations) live in the
// corresponding .cpp. Both take tensors by value and return a new Tensor,
// unlike the reference-based entry points above.
template <typename T>
TORCH_API Tensor quantize_tensor(
    Tensor rtensor,
    Tensor qtensor,
    double scale,
    int64_t zero_point);
template <typename T>
TORCH_API Tensor dequantize_tensor(
    Tensor qtensor,
    Tensor rtensor,
    double scale,
    int64_t zero_point);

} // namespace native
} // namespace at