/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"

#include "mlir/IR/OpDefinition.h"  // from @llvm-project

namespace mlir {
namespace TFL {
namespace tac {

// Returns true if `op` is non-null and is neither a TFL QuantizeOp nor a TFL
// DequantizeOp.
bool NotTFLQuantDequantizeOp(Operation* op) {
  if (!op) return false;
  if (llvm::isa<TFL::QuantizeOp, TFL::DequantizeOp>(op)) return false;
  return true;
}

// Returns true if `op` is a terminator op (i.e., it carries the IsTerminator
// trait).
bool IsTerminatorOp(Operation* op) {
  if (!op) return false;
  return op->hasTrait<OpTrait::IsTerminator>();
}

// Tries to guess the inference type of the op from the types of its operands.
InferenceType GetInferenceType(Operation* op) {
  bool float_type_observed = false;
  bool int8_type_observed = false;
  bool uint8_type_observed = false;
  for (auto& input : op->getOpOperands()) {
    auto input_type = input.get().getType();
    if (IsF32ShapedType(input_type)) {
      float_type_observed = true;
    } else if (IsQI8Type(input_type)) {
      int8_type_observed = true;
    } else if (IsQUI8Type(input_type)) {
      uint8_type_observed = true;
    }
  }

  // We should not observe both uint8 & int8 operand types on the same op.
  if (int8_type_observed && uint8_type_observed) return UNKNOWN;

  if (float_type_observed) {
    if (int8_type_observed || uint8_type_observed) {
      return HYBRID;
    } else {
      return FLOAT;
    }
  }

  if (int8_type_observed) {
    return QUANTIZED_INT8;
  }

  if (uint8_type_observed) {
    return QUANTIZED_UINT8;
  }

  // Default to float inference.
  return FLOAT;
}

}  // namespace tac
}  // namespace TFL
}  // namespace mlir