/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <optional>

#include "absl/types/optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"

namespace mlir {
namespace quant {

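// Storage bit widths and sign flag used to build the default quantized types
// cached by the DeviceTarget constructor below.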
constexpr int k8Bits = 8;
constexpr int k32Bits = 32;
constexpr unsigned kSigned = QuantizationFlags::Signed;

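// Caches the commonly used MLIR types (f32, i8, i32) and the default signed
// 8-bit and 32-bit quantized types with their storage ranges. `qi8n_` is the
// narrow-range variant that excludes the minimum storage value.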
DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
  f32_ = FloatType::getF32(ctx_);
  i8_ = IntegerType::get(ctx_, k8Bits);
  i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
  i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
  i32_ = IntegerType::get(ctx_, k32Bits);
  i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
  i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
  any_ = AnyQuantizedType();
  qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
  qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
  qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_);
  assert(qi8n_ == qi8n_);
}

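// Looks up the kernel by name and delegates to KernelSpecs::Find for the
// given signature; returns llvm::None if the kernel has not been registered.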
Optional<KernelSpec> DeviceTarget::GetKernelSpec(
    llvm::StringRef kernel, const KernelSpecs::Signature& signature) const {
  auto kernel_specs_it = specs_.find(kernel);
  if (kernel_specs_it == specs_.end()) return llvm::None;
  return kernel_specs_it->getValue().Find(signature);
}

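// Returns the scale decompose function registered for the op's logical
// kernel, or a null ScaleDecomposeFn if the kernel is not registered.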
ScaleDecomposeFn DeviceTarget::GetDecomposeFn(
    quantfork::QuantizeRegionOp op) const {
  auto kernel_specs_it = specs_.find(op.getLogicalKernel());
  if (kernel_specs_it == specs_.end()) return ScaleDecomposeFn(nullptr);
  return kernel_specs_it->second.GetDecomposeFn();
}

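// Canonicalizes `spec` into a signature element: uniform quantized types are
// widened to an AnyQuantizedType with the same storage bounds, an existing
// AnyQuantizedType passes through unchanged, and any other type (e.g. float)
// is represented by a null AnyQuantizedType.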
void DeviceTarget::AppendToSignature(Type spec,
                                     KernelSpecs::Signature* signature) {
  if (auto quant = spec.dyn_cast_or_null<UniformQuantizedType>()) {
    signature->push_back(AnyQuantizedType::get(
        quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
        quant.getStorageTypeMin(), quant.getStorageTypeMax()));
  } else if (auto any = spec.dyn_cast_or_null<AnyQuantizedType>()) {
    signature->push_back(any);
  } else {  // float
    signature->push_back(AnyQuantizedType());
  }
}

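// Registers a kernel signature with a custom scale function. Note that the
// `dfn` argument is currently not stored with the spec.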
LogicalResult DeviceTarget::RegisterKernel(
    llvm::StringRef kernel, const KernelSpecs::Signature& signature,
    const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
  return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
}

namespace ph = std::placeholders;

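// Registers a kernel signature with a predefined scale constraint. Only
// OutputInputSameScale currently has a decompose implementation
// (DecomposeSameScale); any other constraint value makes registration fail.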
LogicalResult DeviceTarget::RegisterKernel(
    llvm::StringRef kernel, const KernelSpecs::Signature& signature,
    const ScaleConstraintType constraint) {
  if (failed(specs_[kernel].Add(signature, {constraint, {}}))) return failure();
  switch (constraint) {
    case ScaleConstraintType::OutputInputSameScale:
      specs_[kernel].WithImpl(std::bind(&DeviceTarget::DecomposeSameScale,
                                        ph::_1, ph::_2, ph::_3, ph::_4));
      return success();
    default:
      return failure();
  }
}

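// Decomposes the scales of a multiply-accumulate kernel whose region op takes
// quantized (input, weight, bias) operands and produces one quantized output:
// the input multipliers are unit multipliers, the single output multiplier is
// input_scale * weight_scale / output_scale, and the output range is computed
// from the output spec plus the optional "min"/"max" attributes.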
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
    Operation* op, QuantizedMultipliers* input_multipliers,
    QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges) {
  auto rop = llvm::dyn_cast<quantfork::QuantizeRegionOp>(op);
  if (!rop) return failure();

  llvm::SmallVector<Type, 4> input_specs, out_specs;
  for (auto spec : rop.getInputSpecs()) {
    input_specs.push_back(spec.cast<TypeAttr>().getValue());
  }
  for (auto spec : rop.getOutputSpecs()) {
    out_specs.push_back(spec.cast<TypeAttr>().getValue());
  }

  auto in_spec = input_specs[0].dyn_cast<UniformQuantizedType>();
  // TODO(fengliuai): handle the PerAxis QuantizedType.
  auto w_spec = input_specs[1].dyn_cast<UniformQuantizedType>();
  auto b_spec = input_specs[2].dyn_cast<UniformQuantizedType>();
  auto o_spec = out_specs[0].dyn_cast<UniformQuantizedType>();
  if (!in_spec || !w_spec || !b_spec || !o_spec) return failure();

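  // The bias is accumulated at the product of the input and weight scales, so
  // its quantized scale must match `in_scale * w_scale` (within a small
  // tolerance); otherwise this decomposition does not apply.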
  double scale_product = in_spec.getScale() * w_spec.getScale();
  if (std::fabs(scale_product - b_spec.getScale()) >= 1e-6) return failure();

  // input multipliers
  input_multipliers->append(3, kUnitQuantizedMultiplier);

  // output multipliers
  double real_multiplier = scale_product / o_spec.getScale();
  output_multipliers->push_back(QuantizeMultiplier(real_multiplier));

  // output ranges
  auto min = rop->getAttrOfType<FloatAttr>("min");
  auto max = rop->getAttrOfType<FloatAttr>("max");
  output_ranges->push_back(CalculateQuantizedRange(
      o_spec.getScale(), o_spec.getZeroPoint(),
      (min ? std::optional<double>(min.getValueAsDouble()) : std::nullopt),
      (max ? std::optional<double>(max.getValueAsDouble()) : std::nullopt),
      o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));

  return success();
}

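// Decomposes the scales of a kernel whose outputs must use the same scale as
// its inputs: every operand and result gets a unit multiplier, and the output
// range is derived from the first output spec and the optional "min"/"max"
// attributes.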
LogicalResult DeviceTarget::DecomposeSameScale(
    Operation* op, QuantizedMultipliers* input_multipliers,
    QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges) {
  auto rop = llvm::dyn_cast<quantfork::QuantizeRegionOp>(op);
  if (!rop) return failure();

  // input multipliers
  for (int i = 0; i < op->getNumOperands(); ++i) {
    input_multipliers->push_back(kUnitQuantizedMultiplier);
  }

  // output multipliers
  for (int i = 0; i < op->getNumResults(); ++i) {
    output_multipliers->push_back(kUnitQuantizedMultiplier);
  }

  auto o_spec = rop.getOutputSpecs()[0]
                    .cast<TypeAttr>()
                    .getValue()
                    .dyn_cast<UniformQuantizedType>();
  if (!o_spec) return failure();

  // output ranges
  auto min = rop->getAttrOfType<FloatAttr>("min");
  auto max = rop->getAttrOfType<FloatAttr>("max");
  output_ranges->push_back(CalculateQuantizedRange(
      o_spec.getScale(), o_spec.getZeroPoint(),
      (min ? std::optional<double>(min.getValueAsDouble()) : std::nullopt),
      (max ? std::optional<double>(max.getValueAsDouble()) : std::nullopt),
      o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));

  return success();
}

}  // namespace quant
}  // namespace mlir