xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <type_traits>
16 
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/Matchers.h"  // from @llvm-project
21 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
22 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
23 #include "mlir/IR/Value.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
26 
27 #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_UTIL_H_
28 #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_UTIL_H_
29 namespace mlir {
30 namespace quant {
31 
32 // TODO(b/238829558): Populate quantization config based on the
33 // QuantizationOptions proto. We might want to clean QuantizationMethod as well
34 // as this can be inferred from the proto.
35 using OpSet = tensorflow::quantization::OpSet;
36 
// Identifies which quantization method is in use.
enum class QuantizationMethod {
  // Quantization-aware training.
  kQuantizationAwareTraining,
  // Post-training quantization.
  kPostTrainingQuantization,
  // Dynamic-range quantization.
  kDynamicRangeQuantization,
};
42 
43 // Returns true if the value has static shape.
44 bool HasStaticShape(Value value);
45 
46 // Returns true if the value has static shape at given dims.
47 bool HasStaticShapeAtDims(Value value, llvm::ArrayRef<int> dims);
48 
49 // Returns true if the op has any quantized tensors as input or output.
50 bool HasQuantizedTensors(Operation *op);
51 
52 // Creates a new type that has the shape from the `old_type` and the element
53 // type from the `element_type`.
54 Type CloneTypeWithNewElementType(Type old_type, Type element_type);
55 
56 // Creates an array with integer/float type.
57 template <typename T>
CreateConstValue(OpBuilder & builder,Location loc,const llvm::SmallVector<int64_t> & shape,const llvm::SmallVector<T> & values)58 Value CreateConstValue(OpBuilder &builder, Location loc,
59                        const llvm::SmallVector<int64_t> &shape,
60                        const llvm::SmallVector<T> &values) {
61   static_assert(std::is_integral_v<T> || std::is_same_v<T, float>);
62   if (std::is_integral_v<T>) {
63     auto shape_type =
64         RankedTensorType::get(shape, builder.getIntegerType(sizeof(T) * 8));
65 
66     DenseIntElementsAttr attr = DenseIntElementsAttr::get(shape_type, values);
67     return builder.create<TF::ConstOp>(loc, attr);
68   }
69 
70   auto type = RankedTensorType::get(shape, builder.getF32Type());
71   auto value_attr = DenseFPElementsAttr::get(type, values);
72   return builder.create<TF::ConstOp>(loc, value_attr);
73 }
74 
75 // Creates a 1D array with integer/float type.
76 template <typename T>
Create1DConstValue(OpBuilder & builder,Location loc,const llvm::SmallVector<T> & values)77 Value Create1DConstValue(OpBuilder &builder, Location loc,
78                          const llvm::SmallVector<T> &values) {
79   return CreateConstValue<T>(builder, loc,
80                              {static_cast<int64_t>(values.size())}, values);
81 }
82 
83 // Creates a scalar with integer/float type.
84 template <typename T>
CreateScalarConstValue(OpBuilder & builder,Location loc,T value)85 Value CreateScalarConstValue(OpBuilder &builder, Location loc, T value) {
86   return CreateConstValue<T>(builder, loc, {}, {value});
87 }
88 
89 // Checks if the value is a constant and return its splat value.
90 template <typename T>
GetSplatValue(Value value,T & splat_value)91 bool GetSplatValue(Value value, T &splat_value) {
92   static_assert(std::is_integral_v<T> || std::is_same_v<T, float>);
93   if (std::is_integral_v<T>) {
94     DenseIntElementsAttr value_attr;
95     if (!matchPattern(value, m_Constant(&value_attr)) ||
96         !value_attr.isSplat()) {
97       return false;
98     }
99     splat_value = value_attr.getSplatValue<T>();
100     return true;
101   }
102 
103   DenseFPElementsAttr value_attr;
104   if (!matchPattern(value, m_Constant(&value_attr)) || !value_attr.isSplat()) {
105     return false;
106   }
107   splat_value = value_attr.getSplatValue<T>();
108 
109   return true;
110 }
111 
112 // Checks if the value is a constant and its splat value is equal to x.
113 template <typename T>
IsSplatValueEqual(Value value,T x)114 bool IsSplatValueEqual(Value value, T x) {
115   T splat_value;
116   if (!GetSplatValue(value, splat_value)) return false;
117 
118   return splat_value == x;
119 }
120 
121 // Checks if two values are constants and their splat values are equal.
122 template <typename T>
AreSplatValuesEqual(Value x,Value y)123 bool AreSplatValuesEqual(Value x, Value y) {
124   T splat_x, splat_y;
125   if (!GetSplatValue(x, splat_x) || !GetSplatValue(y, splat_y)) {
126     return false;
127   }
128 
129   return splat_x == splat_y;
130 }
131 
132 // TODO(b/241488936): Remove this function after adding a new constant folding
133 // pass to TensorFlow.
134 // Applies constant folding to the operation if possible and return the folded
135 // results.
136 llvm::SmallVector<Value> ConstantFoldOpIfPossible(Operation *op);
137 
138 }  // namespace quant
139 }  // namespace mlir
140 #endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_UTIL_H_
141