xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h"
16 
17 #include <memory>
18 
19 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
21 #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h"
22 
23 namespace mlir {
24 namespace quant {
25 
HasQuantizedTensors(Operation * op)26 bool HasQuantizedTensors(Operation* op) {
27   if (IsOpNotQuantizable(op)) return false;
28   for (Type operand_type : op->getOperandTypes()) {
29     auto tensor_type = operand_type.dyn_cast<TensorType>();
30     if (tensor_type && tensor_type.getElementType().isa<QuantizedType>()) {
31       return true;
32     }
33   }
34   for (Type result_type : op->getResultTypes()) {
35     auto tensor_type = result_type.dyn_cast<TensorType>();
36     if (tensor_type && tensor_type.getElementType().isa<QuantizedType>()) {
37       return true;
38     }
39   }
40   return false;
41 }
42 
HasStaticShape(Value value)43 bool HasStaticShape(Value value) {
44   auto shaped_type = value.getType().dyn_cast<ShapedType>();
45   if (!shaped_type) return false;
46 
47   return shaped_type.hasStaticShape();
48 }
49 
HasStaticShapeAtDims(Value value,llvm::ArrayRef<int> dims)50 bool HasStaticShapeAtDims(Value value, llvm::ArrayRef<int> dims) {
51   auto shaped_type = value.getType().dyn_cast<ShapedType>();
52   if (!shaped_type) return false;
53 
54   for (auto dim : dims) {
55     if (shaped_type.isDynamicDim(dim)) return false;
56   }
57   return true;
58 }
59 
CloneTypeWithNewElementType(Type old_type,Type element_type)60 Type CloneTypeWithNewElementType(Type old_type, Type element_type) {
61   if (!old_type.isa<ShapedType>()) return {};
62 
63   return old_type.cast<ShapedType>().clone(element_type);
64 }
65 
66 // These constant folding utilities are forked from
67 // tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc.
68 // TODO(b/241488936): Remove these constant folding utility functions after
69 // adding a new constant folding pass to TensorFlow.
// Returns success() iff `op` — and, recursively, every op producing its
// operands — can be safely constant-folded by executing it on the TF eager
// runtime (see FoldOperation below).
LogicalResult IsOperationFoldable(Operation* op) {
  // Only fold ops from the TF dialect; ConstOp is already a constant, so
  // there is nothing to fold.
  if (!op->getDialect()->getNamespace().equals("tf") ||
      llvm::isa<TF::ConstOp>(op)) {
    return failure();
  }
  // Ops with `NoConstantFold` trait or side effects should not be constant
  // folded to preserve the original semantics.
  if (op->hasTrait<OpTrait::IsTerminator>() ||
      op->hasTrait<OpTrait::TF::NoConstantFold>() || op->getNumRegions() != 0 ||
      !MemoryEffectOpInterface::hasNoEffect(op)) {
    return failure();
  }

  // If any of the result types are variants, don't try to constant fold them.
  // This creates opaque variant constants which lose information and would
  // require "raising" later.
  for (auto type : op->getResultTypes()) {
    if (auto tensor_type = type.dyn_cast<TensorType>()) {
      if (tensor_type.getElementType().isa<TF::VariantType>()) {
        return failure();
      }
    }
  }

  // Do not execute function calls.
  if (llvm::isa<TF::WhileOp, TF::CaseOp, TF::IfOp, CallOpInterface>(op)) {
    return failure();
  }

  // Check if the operands are constants or foldable as well.
  // An operand is acceptable when its defining op is either a ConstOp or is
  // itself recursively foldable; block arguments (no defining op) are not.
  for (auto operand : op->getOperands()) {
    auto preceding_op = operand.getDefiningOp();
    if (!preceding_op || (!llvm::isa<TF::ConstOp>(preceding_op) &&
                          failed(IsOperationFoldable(preceding_op)))) {
      return failure();
    }
  }

  return success();
}
110 
111 // Folds the operation recursively and return the results.
// Recursively folds `op` by evaluating it (and its non-constant producers)
// on the TF eager runtime `ctx`, materializing the folded values as
// TF::ConstOps inserted after `op`. On success, `results` holds one Value per
// result of `op`. Callers must have validated `op` with IsOperationFoldable.
LogicalResult FoldOperation(TFE_Context* ctx, OpBuilder& builder, Operation* op,
                            llvm::SmallVector<Value>& results) {
  results.clear();
  builder.setInsertionPointAfter(op);

  // Special case: all results are statically-shaped, zero-element int/float
  // tensors — no evaluation needed, just emit empty constants.
  bool has_empty_numerical_results =
      llvm::all_of(op->getResultTypes(), [](Type ty) {
        ShapedType shaped_ty = ty.cast<ShapedType>();
        Type element_ty = shaped_ty.getElementType();
        return shaped_ty.hasStaticShape() && shaped_ty.getNumElements() == 0 &&
               element_ty.isIntOrFloat();
      });
  if (has_empty_numerical_results && op->isRegistered()) {
    for (Type ty : op->getResultTypes()) {
      auto shaped_ty = ty.cast<ShapedType>();
      results.push_back(builder.create<TF::ConstOp>(
          op->getLoc(),
          DenseElementsAttr::get(shaped_ty, llvm::ArrayRef<Attribute>())));
    }
    return success();
  }

  // Gather each operand as an ElementsAttr, folding non-constant producers
  // recursively first.
  SmallVector<ElementsAttr, 4> inputs;
  for (auto operand : op->getOperands()) {
    auto preceding_const_op = operand.getDefiningOp<TF::ConstOp>();
    if (preceding_const_op) {
      inputs.push_back(preceding_const_op.value());
      continue;
    }

    Operation* preceding_op = operand.getDefiningOp();
    // Find which result of `preceding_op` this operand is; the loop always
    // matches because `operand` is by definition one of its results.
    int preceding_result_id = -1;
    for (auto preceding_result : preceding_op->getResults()) {
      if (operand == preceding_result) {
        preceding_result_id = preceding_result.getResultNumber();
        break;
      }
    }
    llvm::SmallVector<Value> preceding_results;
    if (failed(FoldOperation(ctx, builder, preceding_op, preceding_results))) {
      return failure();
    }
    // The recursive fold produced ConstOps, so this cast is expected to hold.
    auto preceding_result = preceding_results[preceding_result_id];
    preceding_const_op = preceding_result.getDefiningOp<TF::ConstOp>();
    inputs.push_back(preceding_const_op.value());
  }

  // Avoid overlapping folds with the same context.
  static auto* mu = new tensorflow::mutex();
  tensorflow::mutex_lock l(*mu);
  SmallVector<Attribute, 8> constants;
  if (failed(tensorflow::EvaluateOperation(op, inputs, ctx, &constants))) {
    return failure();
  }
  // Materialize each evaluated attribute as a ConstOp after `op`.
  for (const auto& constant : constants) {
    results.push_back(builder.create<TF::ConstOp>(op->getLoc(), constant));
  }
  return success();
}
171 
InitializeTFRuntime()172 TFE_Context* InitializeTFRuntime() {
173   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
174       TF_NewStatus(), TF_DeleteStatus);
175   std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
176       TFE_NewContextOptions(), TFE_DeleteContextOptions);
177   // Only initialize single CPU.
178   tensorflow::ConfigProto config_proto;
179   // This is conceptually equal to what we do in python/eager/context.py but
180   // with all GPU devices ignored and CPU only set to 1.
181   (*config_proto.mutable_device_count())["CPU"] = 1;
182   (*config_proto.mutable_device_count())["GPU"] = 0;
183   std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
184       TF_NewBuffer(), TF_DeleteBuffer);
185   DCHECK(config->data == nullptr);
186 
187   // Copy config_proto into config.
188   {
189     const size_t proto_size = config_proto.ByteSizeLong();
190     void* buf = tensorflow::port::Malloc(proto_size);
191     if (buf == nullptr) {
192       LOG(ERROR) << "Failed to allocate memory to serialize ConfigProto "
193                     "while creating context options for constant folding";
194       return nullptr;
195     }
196     if (!config_proto.SerializeWithCachedSizesToArray(
197             static_cast<uint8_t*>(buf))) {
198       tensorflow::port::Free(buf);
199       LOG(ERROR) << "Unable to serialize ConfigProto while creating context "
200                     "options for constant folding";
201       return nullptr;
202     }
203     config->data = buf;
204     config->length = proto_size;
205     config->data_deallocator = [](void* data, size_t length) {
206       tensorflow::port::Free(data);
207     };
208   }
209 
210   TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
211                               status.get());
212   if (TF_GetCode(status.get()) != TF_OK) {
213     LOG(ERROR) << "Failed to set context options for constant folding: "
214                << status.get();
215     return nullptr;
216   }
217 
218   // Input tensors are placed on the host CPU so use the explicit device
219   // policy to fail if no CPU kernels are available for the op.
220   TFE_ContextOptionsSetDevicePlacementPolicy(opts.get(),
221                                              TFE_DEVICE_PLACEMENT_EXPLICIT);
222   auto ctx = TFE_NewContext(opts.get(), status.get());
223   if (TF_GetCode(status.get()) != TF_OK) {
224     LOG(ERROR) << "Failed to create context for constant folding: "
225                << status.get();
226     return nullptr;
227   }
228   return ctx;
229 }
230 
ConstantFoldOpIfPossible(Operation * op)231 llvm::SmallVector<Value> ConstantFoldOpIfPossible(Operation* op) {
232   if (failed(IsOperationFoldable(op))) return op->getResults();
233 
234   static TFE_Context* ctx = InitializeTFRuntime();
235   if (!ctx) return op->getResults();
236 
237   OpBuilder builder(op);
238   llvm::SmallVector<Value> results;
239   if (failed(FoldOperation(ctx, builder, op, results))) {
240     return op->getResults();
241   }
242   return results;
243 }
244 
245 }  // namespace quant
246 }  // namespace mlir
247