/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_

#include <functional>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"

namespace mlir {
namespace quant {

class QuantizeContext;

using AdjacentOperations = llvm::SmallVectorImpl<Operation*>;
using QuantizedMultipliers = llvm::SmallVector<QuantizedMultiplier, 4>;
using QuantizedRanges = llvm::SmallVector<QuantizedRange, 4>;
using ScaleFn = std::function<LogicalResult(QuantizeContext*, Operation*,
                                            AdjacentOperations*, bool*)>;

using ScaleDecomposeFn =
    std::function<LogicalResult(Operation*, QuantizedMultipliers*,
                                QuantizedMultipliers*, QuantizedRanges*)>;
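
// Illustrative sketch (not part of the original header): a custom `ScaleFn`
// implementation has the shape below. The function name and body are
// hypothetical and only demonstrate the calling convention of the callback.
//
//   LogicalResult ExampleCustomScaleFn(QuantizeContext* ctx, Operation* op,
//                                      AdjacentOperations* new_items,
//                                      bool* changed) {
//     // Derive quantization parameters for `op`, append any adjacent ops
//     // whose quantization state should be revisited to `new_items`, and
//     // set `*changed` when new parameters were propagated.
//     *changed = false;
//     return success();
//   }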

static const QuantizedMultiplier kUnitQuantizedMultiplier{1, 0};

enum class ScaleConstraintType {
  OutputInputSameScale,
  OutputInputFreeScale,
  CustomScale,
};

// Each kernel signature has its own specification for scales.
struct KernelSpec {
  // Scale constraint
  ScaleConstraintType type;

  // Custom function to derive the scales. Only available when the scale
  // constraint is `CustomScale`.
  ScaleFn scale_fn;
};
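
// Hedged example, reusing the hypothetical ExampleCustomScaleFn sketched
// above: a same-scale constraint carries no custom function, while a
// custom-scale constraint bundles one.
//
//   KernelSpec same_scale_spec{ScaleConstraintType::OutputInputSameScale,
//                              /*scale_fn=*/nullptr};
//   KernelSpec custom_spec{ScaleConstraintType::CustomScale,
//                          ExampleCustomScaleFn};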

class KernelSpecs {
 public:
  using Signature = llvm::SmallVector<quant::AnyQuantizedType, 4>;

  // Returns the kernel specification for the kernel signature.
  Optional<KernelSpec> Find(const Signature& signature) const {
    auto spec_it = all_signatures_.find(signature);
    if (spec_it != all_signatures_.end()) {
      return spec_it->second;
    } else {
      return llvm::None;
    }
  }

  ScaleDecomposeFn GetDecomposeFn() const { return decompose_fn_; }

  // Adds the kernel signature with the kernel specification.
  LogicalResult Add(const Signature& signature, const KernelSpec& spec) {
    if (all_signatures_.insert({signature, spec}).second) return success();
    return failure();
  }

  KernelSpecs& WithSignature(const KernelSpecs::Signature& signature,
                             const ScaleFn& fn) {
    (void)Add(signature, {ScaleConstraintType::CustomScale, fn});
    return *this;
  }

  KernelSpecs& WithImpl(const ScaleDecomposeFn& dfn) {
    decompose_fn_ = dfn;
    return *this;
  }

 private:
  // The signature is pattern match based.
  struct SignatureInfo : public llvm::DenseMapInfo<Signature> {
    static inline Signature getEmptyKey() { return {}; }
    static inline Signature getTombstoneKey() { return {nullptr}; }
    static unsigned getHashValue(Signature val) {
      return llvm::hash_combine_range(val.begin(), val.end());
    }
    static bool isEqual(Signature LHS, Signature RHS) {
      if (RHS == getEmptyKey()) return LHS == getEmptyKey();
      if (RHS == getTombstoneKey()) return LHS == getTombstoneKey();
      if (LHS.size() != RHS.size()) return false;
      for (auto arg : llvm::zip(LHS, RHS)) {
        if (std::get<0>(arg) != std::get<1>(arg)) return false;
      }
      return true;
    }
  };

  // Maps the signature to the kernel spec. Note that the matching is
  // pattern match based.
  llvm::DenseMap<Signature, KernelSpec, SignatureInfo> all_signatures_;

  // A method to compute the effective multipliers. This is independent of the
  // bits of the ports, so all the signatures share the same function here.
  ScaleDecomposeFn decompose_fn_;
};
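
// Usage sketch (illustrative only; `qi8`, ExampleCustomScaleFn, and
// ExampleDecomposeFn are hypothetical placeholders for a concrete
// AnyQuantizedType, a custom scale function, and a scale decomposition
// function):
//
//   KernelSpecs specs;
//   specs.WithSignature({qi8, qi8, qi8}, ExampleCustomScaleFn)
//       .WithImpl(ExampleDecomposeFn);
//   if (Optional<KernelSpec> spec = specs.Find({qi8, qi8, qi8})) {
//     // spec->scale_fn can be invoked for ops matching this signature.
//   }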

class DeviceTarget {
 public:
  explicit DeviceTarget(MLIRContext* ctx);

  // Retrieves the kernel spec for the quant region op.
  Optional<KernelSpec> GetKernelSpec(
      llvm::StringRef kernel, const KernelSpecs::Signature& signature) const;

  // Retrieves the scale decomposition function for the quant region op.
  ScaleDecomposeFn GetDecomposeFn(quantfork::QuantizeRegionOp op) const;

  // Converts the specification to a signature:
  // - UniformQuantizedType -> AnyQuantizedType
  // - AnyQuantizedType (int) -> AnyQuantizedType
  // - Float -> {}
  static void AppendToSignature(Type spec, KernelSpecs::Signature* signature);

 protected:
  // Adds the kernel spec with the custom scale function for the kernel.
  LogicalResult RegisterKernel(llvm::StringRef kernel,
                               const KernelSpecs::Signature& signature,
                               const ScaleFn& fn, const ScaleDecomposeFn& dfn);

  // Adds the kernel spec with the scale constraint type for the kernel.
  LogicalResult RegisterKernel(llvm::StringRef kernel,
                               const KernelSpecs::Signature& signature,
                               const ScaleConstraintType constraint);

  // Adds the kernel with the name. Returns the existing one if it has been
  // added before.
  KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }

  // For "matmul->add" type of kernels, convert the scales of all the ports to
  // multipliers.
  static LogicalResult DecomposeMultiplyAccumulateScale(
      Operation* op, QuantizedMultipliers* input_multipliers,
      QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges);

  // For "reshape" type of kernels.
  static LogicalResult DecomposeSameScale(
      Operation* op, QuantizedMultipliers* input_multipliers,
      QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges);

  // A set of parameters required to build the signatures.
  FloatType f32_;
  IntegerType i8_, i32_;
  int64_t i8_min_, i8_max_, i32_min_, i32_max_;
  quant::AnyQuantizedType any_, qi8_, qi8n_, qi32_;

 private:
  // Maps the kernel names to all the available kernels.
  llvm::StringMap<KernelSpecs> specs_;

  // Points to the global MLIRContext.
  MLIRContext* ctx_;
};
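
// Illustrative sketch of a hypothetical subclass (real device targets are
// defined elsewhere, and the kernel names below are made up): a concrete
// target typically registers its kernels in its constructor through the
// protected helpers above.
//
//   class ExampleDeviceTarget : public DeviceTarget {
//    public:
//     explicit ExampleDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
//       // Same-scale kernel: the output reuses the input quantization params.
//       (void)RegisterKernel("example.reshape", {qi8_, qi8_},
//                            ScaleConstraintType::OutputInputSameScale);
//       // Custom-scale kernel with a multiply-accumulate decomposition.
//       RegisterKernel("example.mul_add")
//           .WithSignature({qi8n_, qi8n_, qi32_}, ExampleCustomScaleFn)
//           .WithImpl(DecomposeMultiplyAccumulateScale);
//     }
//   };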

}  // namespace quant
}  // namespace mlir
#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_