xref: /aosp_15_r20/external/pytorch/c10/core/QScheme.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/Exception.h>
4 #include <cstdint>
5 #include <string>
6 
7 namespace c10 {
8 
/**
 * QScheme enumerates the kinds of quantization. Each scheme has a one-to-one
 * correspondence with a Quantizer; see ATen/quantized/Quantizer.h for the
 * Quantizer classes.
 * Keep this file in sync with torch/nn/_qscheme.py.
 */
enum class QScheme : std::uint8_t {
  // Values are spelled out explicitly; presumably they must match the
  // Python-side definition mentioned above — verify before renumbering.
  PER_TENSOR_AFFINE = 0,
  PER_CHANNEL_AFFINE = 1,
  PER_TENSOR_SYMMETRIC = 2,
  PER_CHANNEL_SYMMETRIC = 3,
  PER_CHANNEL_AFFINE_FLOAT_QPARAMS = 4,
  // Not a real scheme: a sentinel whose value equals the number of schemes
  // listed above. Keep it last when adding new enumerators.
  COMPILE_TIME_NUM_QSCHEMES = 5
};
23 
24 constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE;
25 constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE;
26 constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC;
27 constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC;
28 constexpr auto kPerChannelAffineFloatQParams =
29     QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS;
30 constexpr int COMPILE_TIME_NUM_QSCHEMES =
31     static_cast<int>(QScheme::COMPILE_TIME_NUM_QSCHEMES);
32 
toString(QScheme qscheme)33 inline std::string toString(QScheme qscheme) {
34   switch (qscheme) {
35     case kPerTensorAffine:
36       return "per_tensor_affine";
37     case kPerChannelAffine:
38       return "per_channel_affine";
39     case kPerTensorSymmetric:
40       return "per_tensor_symmetric";
41     case kPerChannelSymmetric:
42       return "per_channel_symmetric";
43     case kPerChannelAffineFloatQParams:
44       return "per_channel_affine_float_qparams";
45     default:
46       TORCH_CHECK(false, "Unrecognized qscheme: ", static_cast<int>(qscheme));
47   }
48 }
49 
50 } // namespace c10
51