xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/quantization/device_target.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <optional>

#include "absl/types/optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
32 
33 namespace mlir {
34 namespace quant {
35 
36 constexpr int k8Bits = 8;
37 constexpr int k32Bits = 32;
38 constexpr unsigned kSigned = QuantizationFlags::Signed;
39 
DeviceTarget(MLIRContext * ctx)40 DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
41   f32_ = FloatType::getF32(ctx_);
42   i8_ = IntegerType::get(ctx_, k8Bits);
43   i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
44   i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
45   i32_ = IntegerType::get(ctx_, k32Bits);
46   i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
47   i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
48   any_ = AnyQuantizedType();
49   qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
50   qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
51   qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_);
52   assert(qi8n_ == qi8n_);
53 }
54 
GetKernelSpec(llvm::StringRef kernel,const KernelSpecs::Signature & signature) const55 Optional<KernelSpec> DeviceTarget::GetKernelSpec(
56     llvm::StringRef kernel, const KernelSpecs::Signature& signature) const {
57   auto kernel_specs_it = specs_.find(kernel);
58   if (kernel_specs_it == specs_.end()) return llvm::None;
59   return kernel_specs_it->getValue().Find(signature);
60 }
61 
GetDecomposeFn(quantfork::QuantizeRegionOp op) const62 ScaleDecomposeFn DeviceTarget::GetDecomposeFn(
63     quantfork::QuantizeRegionOp op) const {
64   auto kernel_specs_it = specs_.find(op.getLogicalKernel());
65   if (kernel_specs_it == specs_.end()) return ScaleDecomposeFn(nullptr);
66   return kernel_specs_it->second.GetDecomposeFn();
67 }
68 
AppendToSignature(Type spec,KernelSpecs::Signature * signature)69 void DeviceTarget::AppendToSignature(Type spec,
70                                      KernelSpecs::Signature* signature) {
71   if (auto quant = spec.dyn_cast_or_null<UniformQuantizedType>()) {
72     signature->push_back(AnyQuantizedType::get(
73         quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
74         quant.getStorageTypeMin(), quant.getStorageTypeMax()));
75   } else if (auto any = spec.dyn_cast_or_null<AnyQuantizedType>()) {
76     signature->push_back(any);
77   } else {  // float
78     signature->push_back(AnyQuantizedType());
79   }
80 }
81 
RegisterKernel(llvm::StringRef kernel,const KernelSpecs::Signature & signature,const ScaleFn & fn,const ScaleDecomposeFn & dfn)82 LogicalResult DeviceTarget::RegisterKernel(
83     llvm::StringRef kernel, const KernelSpecs::Signature& signature,
84     const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
85   return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
86 }
87 
88 namespace ph = std::placeholders;
89 
RegisterKernel(llvm::StringRef kernel,const KernelSpecs::Signature & signature,const ScaleConstraintType constraint)90 LogicalResult DeviceTarget::RegisterKernel(
91     llvm::StringRef kernel, const KernelSpecs::Signature& signature,
92     const ScaleConstraintType constraint) {
93   if (failed(specs_[kernel].Add(signature, {constraint, {}}))) return failure();
94   switch (constraint) {
95     case ScaleConstraintType::OutputInputSameScale:
96       specs_[kernel].WithImpl(std::bind(&DeviceTarget::DecomposeSameScale,
97                                         ph::_1, ph::_2, ph::_3, ph::_4));
98       return success();
99     default:
100       return failure();
101   }
102 }
103 
DecomposeMultiplyAccumulateScale(Operation * op,QuantizedMultipliers * input_multipliers,QuantizedMultipliers * output_multipliers,QuantizedRanges * output_ranges)104 LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
105     Operation* op, QuantizedMultipliers* input_multipliers,
106     QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges) {
107   auto rop = llvm::dyn_cast<quantfork::QuantizeRegionOp>(op);
108   if (!rop) return failure();
109 
110   llvm::SmallVector<Type, 4> input_specs, out_specs;
111   for (auto spec : rop.getInputSpecs()) {
112     input_specs.push_back(spec.cast<TypeAttr>().getValue());
113   }
114   for (auto spec : rop.getOutputSpecs()) {
115     out_specs.push_back(spec.cast<TypeAttr>().getValue());
116   }
117 
118   auto in_spec = input_specs[0].dyn_cast<UniformQuantizedType>();
119   // TODO(fengliuai): handles the PerAxis QuantizedType.
120   auto w_spec = input_specs[1].dyn_cast<UniformQuantizedType>();
121   auto b_spec = input_specs[2].dyn_cast<UniformQuantizedType>();
122   auto o_spec = out_specs[0].dyn_cast<UniformQuantizedType>();
123   if (!in_spec || !w_spec || !b_spec || !o_spec) return failure();
124 
125   double scale_product = in_spec.getScale() * w_spec.getScale();
126   if (fabs(scale_product - b_spec.getScale()) >= 1e-6) return failure();
127 
128   // input multipliers
129   input_multipliers->append(3, kUnitQuantizedMultiplier);
130 
131   // output multipliers
132   double real_multiplier = scale_product / o_spec.getScale();
133   output_multipliers->push_back(QuantizeMultiplier(real_multiplier));
134 
135   // output ranges
136   auto min = rop->getAttrOfType<FloatAttr>("min");
137   auto max = rop->getAttrOfType<FloatAttr>("max");
138   output_ranges->push_back(CalculateQuantizedRange(
139       o_spec.getScale(), o_spec.getZeroPoint(),
140       (min ? std::optional<double>(min.getValueAsDouble()) : std::nullopt),
141       (max ? std::optional<double>(max.getValueAsDouble()) : std::nullopt),
142       o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
143 
144   return success();
145 }
146 
DecomposeSameScale(Operation * op,QuantizedMultipliers * input_multipliers,QuantizedMultipliers * output_multipliers,QuantizedRanges * output_ranges)147 LogicalResult DeviceTarget::DecomposeSameScale(
148     Operation* op, QuantizedMultipliers* input_multipliers,
149     QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges) {
150   auto rop = llvm::dyn_cast<quantfork::QuantizeRegionOp>(op);
151   if (!rop) return failure();
152 
153   // input multipliers
154   for (int i = 0; i < op->getNumOperands(); ++i) {
155     input_multipliers->push_back(kUnitQuantizedMultiplier);
156   }
157 
158   // output multipliers
159   for (int i = 0; i < op->getNumResults(); ++i) {
160     output_multipliers->push_back(kUnitQuantizedMultiplier);
161   }
162 
163   auto o_spec = rop.getOutputSpecs()[0]
164                     .cast<TypeAttr>()
165                     .getValue()
166                     .dyn_cast<UniformQuantizedType>();
167   if (!o_spec) return failure();
168 
169   // output ranges
170   auto min = rop->getAttrOfType<FloatAttr>("min");
171   auto max = rop->getAttrOfType<FloatAttr>("max");
172   output_ranges->push_back(CalculateQuantizedRange(
173       o_spec.getScale(), o_spec.getZeroPoint(),
174       (min ? std::optional<double>(min.getValueAsDouble()) : std::nullopt),
175       (max ? std::optional<double>(max.getValueAsDouble()) : std::nullopt),
176       o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
177 
178   return success();
179 }
180 
181 }  // namespace quant
182 }  // namespace mlir
183