/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h"

#include <memory>

#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h"

namespace mlir {
namespace quant {

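// Returns true if `op` is quantizable (per IsOpNotQuantizable) and has at
// least one operand or result whose tensor element type is a QuantizedType.
// Illustrative example: an operand of type
// tensor<1x3x!quant.uniform<i8:f32, 1.0>> makes this return true, while an op
// with only tensor<1x3xf32> operands and results returns false.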
bool HasQuantizedTensors(Operation* op) {
  if (IsOpNotQuantizable(op)) return false;
  for (Type operand_type : op->getOperandTypes()) {
    auto tensor_type = operand_type.dyn_cast<TensorType>();
    if (tensor_type && tensor_type.getElementType().isa<QuantizedType>()) {
      return true;
    }
  }
  for (Type result_type : op->getResultTypes()) {
    auto tensor_type = result_type.dyn_cast<TensorType>();
    if (tensor_type && tensor_type.getElementType().isa<QuantizedType>()) {
      return true;
    }
  }
  return false;
}

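// Returns true if `value` has a fully static shape. For example
// (illustrative), a value of type tensor<2x3xf32> yields true;
// tensor<?x3xf32> or a non-shaped type yields false.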
bool HasStaticShape(Value value) {
  auto shaped_type = value.getType().dyn_cast<ShapedType>();
  if (!shaped_type) return false;

  return shaped_type.hasStaticShape();
}

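// Returns true if `value` has a static size at every dimension listed in
// `dims`, even when other dimensions are dynamic. For example (illustrative),
// for a value of type tensor<?x3xf32>, HasStaticShape is false but
// HasStaticShapeAtDims(value, /*dims=*/{1}) is true, since only dim 0 is
// dynamic.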
bool HasStaticShapeAtDims(Value value, llvm::ArrayRef<int> dims) {
  auto shaped_type = value.getType().dyn_cast<ShapedType>();
  if (!shaped_type) return false;

  for (auto dim : dims) {
    if (shaped_type.isDynamicDim(dim)) return false;
  }
  return true;
}

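// Returns a copy of `old_type` with its element type replaced by
// `element_type`, or a null Type if `old_type` is not a ShapedType.
// Illustrative example: cloning tensor<2x3xf32> with an i8 element type
// yields tensor<2x3xi8>.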
Type CloneTypeWithNewElementType(Type old_type, Type element_type) {
  if (!old_type.isa<ShapedType>()) return {};

  return old_type.cast<ShapedType>().clone(element_type);
}

// These constant folding utilities are forked from
// tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc.
// TODO(b/241488936): Remove these constant folding utility functions after
// adding a new constant folding pass to TensorFlow.
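// Returns success if `op` is a foldable TF op: registered in the `tf`
// dialect, side-effect free, not already a tf.Const, and (recursively) fed
// only by constants or foldable ops. Illustrative example: given
// %0 = "tf.Const"() and %1 = "tf.Add"(%0, %0), the tf.Add is foldable while
// the tf.Const itself is not (there is nothing left to fold).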
LogicalResult IsOperationFoldable(Operation* op) {
  if (!op->getDialect()->getNamespace().equals("tf") ||
      llvm::isa<TF::ConstOp>(op)) {
    return failure();
  }
  // Terminators, ops with nested regions, ops with the `NoConstantFold`
  // trait, and ops with side effects should not be constant folded to
  // preserve the original semantics.
  if (op->hasTrait<OpTrait::IsTerminator>() ||
      op->hasTrait<OpTrait::TF::NoConstantFold>() || op->getNumRegions() != 0 ||
      !MemoryEffectOpInterface::hasNoEffect(op)) {
    return failure();
  }

  // If any of the result types are variants, don't try to constant fold them.
  // This creates opaque variant constants which lose information and would
  // require "raising" later.
  for (auto type : op->getResultTypes()) {
    if (auto tensor_type = type.dyn_cast<TensorType>()) {
      if (tensor_type.getElementType().isa<TF::VariantType>()) {
        return failure();
      }
    }
  }

  // Do not execute function calls.
  if (llvm::isa<TF::WhileOp, TF::CaseOp, TF::IfOp, CallOpInterface>(op)) {
    return failure();
  }

  // Check if the operands are constants or foldable as well.
  for (auto operand : op->getOperands()) {
    auto preceding_op = operand.getDefiningOp();
    if (!preceding_op || (!llvm::isa<TF::ConstOp>(preceding_op) &&
                          failed(IsOperationFoldable(preceding_op)))) {
      return failure();
    }
  }

  return success();
}

// Recursively folds the operation and its constant-producing operands,
// returning the folded values in `results`.
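// On success, each entry of `results` is produced by a newly created tf.Const
// op inserted right after `op`; constant operands are gathered by folding the
// defining ops first.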
LogicalResult FoldOperation(TFE_Context* ctx, OpBuilder& builder, Operation* op,
                            llvm::SmallVector<Value>& results) {
  results.clear();
  builder.setInsertionPointAfter(op);

  // If all results are empty tensors with static shape and int/float element
  // type, materialize empty constants directly instead of evaluating the op.
  bool has_empty_numerical_results =
      llvm::all_of(op->getResultTypes(), [](Type ty) {
        ShapedType shaped_ty = ty.cast<ShapedType>();
        Type element_ty = shaped_ty.getElementType();
        return shaped_ty.hasStaticShape() && shaped_ty.getNumElements() == 0 &&
               element_ty.isIntOrFloat();
      });
  if (has_empty_numerical_results && op->isRegistered()) {
    for (Type ty : op->getResultTypes()) {
      auto shaped_ty = ty.cast<ShapedType>();
      results.push_back(builder.create<TF::ConstOp>(
          op->getLoc(),
          DenseElementsAttr::get(shaped_ty, llvm::ArrayRef<Attribute>())));
    }
    return success();
  }

  SmallVector<ElementsAttr, 4> inputs;
  for (auto operand : op->getOperands()) {
    auto preceding_const_op = operand.getDefiningOp<TF::ConstOp>();
    if (preceding_const_op) {
      inputs.push_back(preceding_const_op.value());
      continue;
    }

    // Find which result of the defining op feeds this operand, then fold the
    // defining op so that result becomes a tf.Const.
    Operation* preceding_op = operand.getDefiningOp();
    int preceding_result_id = -1;
    for (auto preceding_result : preceding_op->getResults()) {
      if (operand == preceding_result) {
        preceding_result_id = preceding_result.getResultNumber();
        break;
      }
    }
    llvm::SmallVector<Value> preceding_results;
    if (failed(FoldOperation(ctx, builder, preceding_op, preceding_results))) {
      return failure();
    }
    auto preceding_result = preceding_results[preceding_result_id];
    preceding_const_op = preceding_result.getDefiningOp<TF::ConstOp>();
    inputs.push_back(preceding_const_op.value());
  }

  // Avoid overlapping folds with the same context.
  static auto* mu = new tensorflow::mutex();
  tensorflow::mutex_lock l(*mu);
  SmallVector<Attribute, 8> constants;
  if (failed(tensorflow::EvaluateOperation(op, inputs, ctx, &constants))) {
    return failure();
  }
  for (const auto& constant : constants) {
    results.push_back(builder.create<TF::ConstOp>(op->getLoc(), constant));
  }
  return success();
}

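// Creates a TFE_Context configured with a single host CPU (and no GPUs) for
// evaluating ops during constant folding. Returns nullptr if the context
// options cannot be serialized or the context cannot be created. The caller
// owns the returned context; ConstantFoldOpIfPossible below caches it in a
// function-local static, so it is created at most once per process.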
TFE_Context* InitializeTFRuntime() {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  // Only initialize a single CPU device.
  tensorflow::ConfigProto config_proto;
  // This is conceptually equivalent to what we do in python/eager/context.py,
  // but with all GPU devices ignored and the CPU device count set to 1.
  (*config_proto.mutable_device_count())["CPU"] = 1;
  (*config_proto.mutable_device_count())["GPU"] = 0;
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_NewBuffer(), TF_DeleteBuffer);
  DCHECK(config->data == nullptr);

  // Copy config_proto into config.
  {
    const size_t proto_size = config_proto.ByteSizeLong();
    void* buf = tensorflow::port::Malloc(proto_size);
    if (buf == nullptr) {
      LOG(ERROR) << "Failed to allocate memory to serialize ConfigProto "
                    "while creating context options for constant folding";
      return nullptr;
    }
    if (!config_proto.SerializeWithCachedSizesToArray(
            static_cast<uint8_t*>(buf))) {
      tensorflow::port::Free(buf);
      LOG(ERROR) << "Unable to serialize ConfigProto while creating context "
                    "options for constant folding";
      return nullptr;
    }
    config->data = buf;
    config->length = proto_size;
    config->data_deallocator = [](void* data, size_t length) {
      tensorflow::port::Free(data);
    };
  }

  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  if (TF_GetCode(status.get()) != TF_OK) {
    LOG(ERROR) << "Failed to set context options for constant folding: "
               << TF_Message(status.get());
    return nullptr;
  }

  // Input tensors are placed on the host CPU, so use the explicit device
  // policy to fail if no CPU kernels are available for the op.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts.get(),
                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
  auto ctx = TFE_NewContext(opts.get(), status.get());
  if (TF_GetCode(status.get()) != TF_OK) {
    LOG(ERROR) << "Failed to create context for constant folding: "
               << TF_Message(status.get());
    return nullptr;
  }
  return ctx;
}

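// Attempts to constant fold `op` and returns the folded values, or
// `op->getResults()` unchanged if `op` is not foldable or evaluation fails,
// so callers can use the return value unconditionally. Hypothetical usage
// from a rewrite pattern (illustrative, not from this file):
//
//   llvm::SmallVector<Value> folded = ConstantFoldOpIfPossible(add_op);
//   rewriter.replaceOp(add_op, folded);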
llvm::SmallVector<Value> ConstantFoldOpIfPossible(Operation* op) {
  if (failed(IsOperationFoldable(op))) return op->getResults();

  static TFE_Context* ctx = InitializeTFRuntime();
  if (!ctx) return op->getResults();

  OpBuilder builder(op);
  llvm::SmallVector<Value> results;
  if (failed(FoldOperation(ctx, builder, op, results))) {
    return op->getResults();
  }
  return results;
}

}  // namespace quant
}  // namespace mlir