1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ 17 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ 18 19 #include <functional> 20 21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/Hashing.h" 23 #include "llvm/ADT/MapVector.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/ADT/StringMap.h" 26 #include "llvm/ADT/StringRef.h" 27 #include "llvm/Support/ErrorHandling.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project 30 #include "mlir/IR/Attributes.h" // from @llvm-project 31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 32 #include "mlir/IR/MLIRContext.h" // from @llvm-project 33 #include "mlir/IR/Types.h" // from @llvm-project 34 #include "mlir/Support/LogicalResult.h" // from @llvm-project 35 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" 36 #include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h" 37 38 namespace mlir { 39 namespace quant { 40 41 class QuantizeContext; 42 43 using AdjacentOperations = llvm::SmallVectorImpl<Operation*>; 44 using QuantizedMultipliers = llvm::SmallVector<QuantizedMultiplier, 4>; 45 using QuantizedRanges = llvm::SmallVector<QuantizedRange, 4>; 46 using ScaleFn = std::function<LogicalResult(QuantizeContext*, Operation*, 47 AdjacentOperations*, bool*)>; 48 49 using ScaleDecomposeFn = 50 std::function<LogicalResult(Operation*, QuantizedMultipliers*, 51 QuantizedMultipliers*, QuantizedRanges*)>; 52 53 static const QuantizedMultiplier kUnitQuantizedMultiplier{1, 0}; 54 55 enum class ScaleConstraintType { 56 OutputInputSameScale, 57 OutputInputFreeScale, 58 CustomScale, 59 }; 60 61 // Each kernel signature has its own specification for scales. 62 struct KernelSpec { 63 // Scale constraint 64 ScaleConstraintType type; 65 66 // Custom function to derive the scales. Only available when the scale 67 // constraint is `CustomScale`. 68 ScaleFn scale_fn; 69 }; 70 71 class KernelSpecs { 72 public: 73 using Signature = llvm::SmallVector<quant::AnyQuantizedType, 4>; 74 75 // Returns the kernel specification for the kernel signature. Find(const Signature & signature)76 Optional<KernelSpec> Find(const Signature& signature) const { 77 auto spec_it = all_signatures_.find(signature); 78 if (spec_it != all_signatures_.end()) { 79 return spec_it->second; 80 } else { 81 return llvm::None; 82 } 83 } 84 GetDecomposeFn()85 ScaleDecomposeFn GetDecomposeFn() const { return decompose_fn_; } 86 87 // Adds the kernel signature with the kernel specification. Add(const Signature & signature,const KernelSpec & spec)88 LogicalResult Add(const Signature& signature, const KernelSpec& spec) { 89 if (all_signatures_.insert({signature, spec}).second) return success(); 90 return failure(); 91 } 92 WithSignature(const KernelSpecs::Signature & signature,const ScaleFn & fn)93 KernelSpecs& WithSignature(const KernelSpecs::Signature& signature, 94 const ScaleFn& fn) { 95 (void)Add(signature, {ScaleConstraintType::CustomScale, fn}); 96 return *this; 97 } 98 WithImpl(const ScaleDecomposeFn & dfn)99 KernelSpecs& WithImpl(const ScaleDecomposeFn& dfn) { 100 decompose_fn_ = dfn; 101 return *this; 102 } 103 104 private: 105 // The signature is pattern match based. 106 struct SignatureInfo : public llvm::DenseMapInfo<Signature> { getEmptyKeySignatureInfo107 static inline Signature getEmptyKey() { return {}; } getTombstoneKeySignatureInfo108 static inline Signature getTombstoneKey() { return {nullptr}; } getHashValueSignatureInfo109 static unsigned getHashValue(Signature val) { 110 return llvm::hash_combine_range(val.begin(), val.end()); 111 } isEqualSignatureInfo112 static bool isEqual(Signature LHS, Signature RHS) { 113 if (RHS == getEmptyKey()) return LHS == getEmptyKey(); 114 if (RHS == getTombstoneKey()) return LHS == getTombstoneKey(); 115 if (LHS.size() != RHS.size()) return false; 116 for (auto arg : llvm::zip(LHS, RHS)) { 117 if (std::get<0>(arg) != std::get<1>(arg)) return false; 118 } 119 return true; 120 } 121 }; 122 123 // Maps the signature to the kernel spec. Note that the matching is 124 // pattern match based. 125 llvm::DenseMap<Signature, KernelSpec, SignatureInfo> all_signatures_; 126 127 // A method to compute the effective multipliers. This is independent on the 128 // bits of the ports, thus all the signature shares the same here. 129 ScaleDecomposeFn decompose_fn_; 130 }; 131 132 class DeviceTarget { 133 public: 134 explicit DeviceTarget(MLIRContext* ctx); 135 136 // Retrieves the kernel spec for the quant region op. 137 Optional<KernelSpec> GetKernelSpec( 138 llvm::StringRef kernel, const KernelSpecs::Signature& signature) const; 139 140 // Retrieves the scale decomposition function for the quant region op. 141 ScaleDecomposeFn GetDecomposeFn(quantfork::QuantizeRegionOp op) const; 142 143 // converts specification to signature: 144 // - UniformedQuantizedType -> AnyQuantizedType 145 // - AnyQuantizedType (int) -> AnyQuantizedType 146 // - Float -> {} 147 static void AppendToSignature(Type spec, KernelSpecs::Signature* signature); 148 149 protected: 150 // Adds the kernel spec with the custom scale function for the kernel. 151 LogicalResult RegisterKernel(llvm::StringRef kernel, 152 const KernelSpecs::Signature& signature, 153 const ScaleFn& fn, const ScaleDecomposeFn& dfn); 154 155 // Adds the kernel spec with the scale constraint type for the kernel. 156 LogicalResult RegisterKernel(llvm::StringRef kernel, 157 const KernelSpecs::Signature& signature, 158 const ScaleConstraintType constraint); 159 160 // Adds the kernel with the name. Retrun an existing one if it has been 161 // added before. RegisterKernel(llvm::StringRef kernel)162 KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; } 163 164 // For "mulmat->add" type of kernels, convert the scales of all the ports to 165 // multipliers. 166 static LogicalResult DecomposeMultiplyAccumulateScale( 167 Operation* op, QuantizedMultipliers* input_multipliers, 168 QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges); 169 170 // For "reshape" type of kernels. 171 static LogicalResult DecomposeSameScale( 172 Operation* op, QuantizedMultipliers* input_multipliers, 173 QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges); 174 175 // A set of parameters are required to build the signatures. 176 FloatType f32_; 177 IntegerType i8_, i32_; 178 int64_t i8_min_, i8_max_, i32_min_, i32_max_; 179 quant::AnyQuantizedType any_, qi8_, qi8n_, qi32_; 180 181 private: 182 // Maps the kernel names to all the available kernels. 183 llvm::StringMap<KernelSpecs> specs_; 184 185 // Points to the global MLIRContext. 186 MLIRContext* ctx_; 187 }; 188 189 } // namespace quant 190 } // namespace mlir 191 #endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ 192