1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <type_traits>
16
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "mlir/IR/Builders.h" // from @llvm-project
20 #include "mlir/IR/Matchers.h" // from @llvm-project
21 #include "mlir/IR/OpDefinition.h" // from @llvm-project
22 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
23 #include "mlir/IR/Value.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
26
27 #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_UTIL_H_
28 #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_UTIL_H_
29 namespace mlir {
30 namespace quant {
31
32 // TODO(b/238829558): Populate quantization config based on the
33 // QuantizationOptions proto. We might want to clean QuantizationMethod as well
34 // as this can be inferred from the proto.
35 using OpSet = tensorflow::quantization::OpSet;
36
// Identifies which quantization approach is being applied. The enumerator
// values are spelled out explicitly and match the implicit 0, 1, 2 ordering.
enum class QuantizationMethod {
  kQuantizationAwareTraining = 0,
  kPostTrainingQuantization = 1,
  kDynamicRangeQuantization = 2,
};
42
43 // Returns true if the value has static shape.
44 bool HasStaticShape(Value value);
45
46 // Returns true if the value has static shape at given dims.
47 bool HasStaticShapeAtDims(Value value, llvm::ArrayRef<int> dims);
48
49 // Returns true if the op has any quantized tensors as input or output.
50 bool HasQuantizedTensors(Operation *op);
51
52 // Creates a new type that has the shape from the `old_type` and the element
53 // type from the `element_type`.
54 Type CloneTypeWithNewElementType(Type old_type, Type element_type);
55
56 // Creates an array with integer/float type.
57 template <typename T>
CreateConstValue(OpBuilder & builder,Location loc,const llvm::SmallVector<int64_t> & shape,const llvm::SmallVector<T> & values)58 Value CreateConstValue(OpBuilder &builder, Location loc,
59 const llvm::SmallVector<int64_t> &shape,
60 const llvm::SmallVector<T> &values) {
61 static_assert(std::is_integral_v<T> || std::is_same_v<T, float>);
62 if (std::is_integral_v<T>) {
63 auto shape_type =
64 RankedTensorType::get(shape, builder.getIntegerType(sizeof(T) * 8));
65
66 DenseIntElementsAttr attr = DenseIntElementsAttr::get(shape_type, values);
67 return builder.create<TF::ConstOp>(loc, attr);
68 }
69
70 auto type = RankedTensorType::get(shape, builder.getF32Type());
71 auto value_attr = DenseFPElementsAttr::get(type, values);
72 return builder.create<TF::ConstOp>(loc, value_attr);
73 }
74
75 // Creates a 1D array with integer/float type.
76 template <typename T>
Create1DConstValue(OpBuilder & builder,Location loc,const llvm::SmallVector<T> & values)77 Value Create1DConstValue(OpBuilder &builder, Location loc,
78 const llvm::SmallVector<T> &values) {
79 return CreateConstValue<T>(builder, loc,
80 {static_cast<int64_t>(values.size())}, values);
81 }
82
83 // Creates a scalar with integer/float type.
84 template <typename T>
CreateScalarConstValue(OpBuilder & builder,Location loc,T value)85 Value CreateScalarConstValue(OpBuilder &builder, Location loc, T value) {
86 return CreateConstValue<T>(builder, loc, {}, {value});
87 }
88
89 // Checks if the value is a constant and return its splat value.
90 template <typename T>
GetSplatValue(Value value,T & splat_value)91 bool GetSplatValue(Value value, T &splat_value) {
92 static_assert(std::is_integral_v<T> || std::is_same_v<T, float>);
93 if (std::is_integral_v<T>) {
94 DenseIntElementsAttr value_attr;
95 if (!matchPattern(value, m_Constant(&value_attr)) ||
96 !value_attr.isSplat()) {
97 return false;
98 }
99 splat_value = value_attr.getSplatValue<T>();
100 return true;
101 }
102
103 DenseFPElementsAttr value_attr;
104 if (!matchPattern(value, m_Constant(&value_attr)) || !value_attr.isSplat()) {
105 return false;
106 }
107 splat_value = value_attr.getSplatValue<T>();
108
109 return true;
110 }
111
112 // Checks if the value is a constant and its splat value is equal to x.
113 template <typename T>
IsSplatValueEqual(Value value,T x)114 bool IsSplatValueEqual(Value value, T x) {
115 T splat_value;
116 if (!GetSplatValue(value, splat_value)) return false;
117
118 return splat_value == x;
119 }
120
121 // Checks if two values are constants and their splat values are equal.
122 template <typename T>
AreSplatValuesEqual(Value x,Value y)123 bool AreSplatValuesEqual(Value x, Value y) {
124 T splat_x, splat_y;
125 if (!GetSplatValue(x, splat_x) || !GetSplatValue(y, splat_y)) {
126 return false;
127 }
128
129 return splat_x == splat_y;
130 }
131
132 // TODO(b/241488936): Remove this function after adding a new constant folding
133 // pass to TensorFlow.
134 // Applies constant folding to the operation if possible and return the folded
135 // results.
136 llvm::SmallVector<Value> ConstantFoldOpIfPossible(Operation *op);
137
138 } // namespace quant
139 } // namespace mlir
140 #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_UTIL_H_
141