xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"

#include <algorithm>
#include <ios>
#include <sstream>
#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/core/framework/types.pb.h"

// Returns true if `dtype` is a quantization type from TensorFlow.
static bool IsQuantizationType(tensorflow::DataType dtype) {
  switch (dtype) {
    case tensorflow::DT_QINT8:
    case tensorflow::DT_QUINT8:
    case tensorflow::DT_QINT16:
    case tensorflow::DT_QUINT16:
    case tensorflow::DT_QINT32:
      return true;
    default:
      return false;
  }
}

namespace mlir {
namespace quant {
namespace {
bool GetBooleanSpecs(const std::string& bool_val) {
  // Initialize so a failed parse yields a deterministic `false`.
  bool result = false;
  std::stringstream iss(bool_val);
  iss >> std::boolalpha >> result;
  return result;
}
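
// Illustrative note (not part of the original file): because the extraction
// above uses std::boolalpha, only the literal tokens "true" and "false" are
// recognized; with `result` initialized to false, any other input (e.g. "1")
// leaves the returned value as false.
//
//   GetBooleanSpecs("true");   // -> true
//   GetBooleanSpecs("false");  // -> false
//   GetBooleanSpecs("1");      // -> false (parse failure)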
}  // namespace

void ParseCustomOpSpecs(absl::string_view node_names,
                        const CustomOpUpdateOptions& update_option,
                        CustomOpMap& custom_op_map) {
  if (node_names.empty()) return;

  std::vector<std::string> custom_nodes = absl::StrSplit(node_names, ',');

  for (auto& cur_node : custom_nodes) {
    // Each entry is expected to be of the form "node_name=specification".
    std::vector<std::string> node_infos = absl::StrSplit(cur_node, '=');
    std::string node_name = node_infos[0];
    auto node_specification = node_infos[1];
    CustomOpInfo new_node_info;
    switch (update_option) {
      case CustomOpUpdateOptions::kINputIndices: {
        std::vector<std::string> indices =
            absl::StrSplit(node_specification, '-');
        for (auto& cur_index : indices) {
          custom_op_map[node_name].quantizable_input_indices.push_back(
              std::stoi(cur_index));
        }
        break;
      }
      case CustomOpUpdateOptions::kWeightOnly:
        custom_op_map[node_name].is_weight_only =
            GetBooleanSpecs(node_specification);
        break;
      case CustomOpUpdateOptions::kNoSideEffect:
        custom_op_map[node_name].no_side_effect =
            GetBooleanSpecs(node_specification);
        break;
    }
  }
}
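
// Illustrative sketch (not part of the original file): the specification
// strings accepted above, inferred from the parsing logic. The op names are
// hypothetical.
//
//   CustomOpMap custom_op_map;
//   // Mark inputs 0 and 1 of "MyCustomOp" as quantizable (dash-separated).
//   ParseCustomOpSpecs("MyCustomOp=0-1",
//                      CustomOpUpdateOptions::kINputIndices, custom_op_map);
//   // Flag "MyCustomOp" as weight-only quantized.
//   ParseCustomOpSpecs("MyCustomOp=true",
//                      CustomOpUpdateOptions::kWeightOnly, custom_op_map);
//   // Multiple "name=value" entries are comma-separated.
//   ParseCustomOpSpecs("OpA=true,OpB=false",
//                      CustomOpUpdateOptions::kNoSideEffect, custom_op_map);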

bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                              absl::string_view min_values,
                              absl::string_view max_values,
                              absl::string_view inference_type,
                              QuantizationSpecs* quant_specs) {
  std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
  std::vector<llvm::Optional<double>> node_mins;
  if (!min_values.empty()) {
    std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
    for (int i = 0, e = node_mins_str.size(); i < e; i++) {
      double value;
      if (!absl::SimpleAtod(node_mins_str[i], &value)) {
        return true;
      }
      node_mins.push_back(value);
    }
  }

  std::vector<llvm::Optional<double>> node_maxs;
  if (!max_values.empty()) {
    std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
    for (int i = 0, e = node_maxs_str.size(); i < e; i++) {
      double value;
      if (!absl::SimpleAtod(node_maxs_str[i], &value)) {
        llvm::errs() << "Unexpected maxs: " << node_maxs_str[i] << "\n";
        return true;
      }
      node_maxs.push_back(value);
    }
  }

  tensorflow::DataType final_type = tensorflow::DT_FLOAT;
  if (!inference_type.empty() &&
      !DataType_Parse(std::string(inference_type), &final_type)) {
    return true;
  }
  return GetInputNodeQuantSpecs(input_nodes, node_mins, node_maxs, final_type,
                                quant_specs);
}
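
// Illustrative sketch (not part of the original file): a hypothetical call
// with two input nodes. Min/max lists are comma-separated; for quantized
// inference types they must line up with the node names. The function
// returns true on a parse failure.
//
//   QuantizationSpecs quant_specs;
//   bool failed = ParseInputNodeQuantSpecs(
//       /*node_names=*/"input0,input1",
//       /*min_values=*/"-1.0,0.0",
//       /*max_values=*/"1.0,255.0",
//       /*inference_type=*/"DT_QUINT8", &quant_specs);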

bool GetInputNodeQuantSpecs(
    const std::vector<std::string>& node_names,
    const std::vector<llvm::Optional<double>>& node_mins,
    const std::vector<llvm::Optional<double>>& node_maxs,
    tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
  quant_specs->inference_type = inference_type;

  // If min/max are not specified, just return.
  if (node_mins.empty() || node_maxs.empty()) return false;

  // Otherwise make sure min/max have the same size as the inputs.
  if (IsQuantizationType(inference_type)) {
    // min/max should have the same size as the inputs, or shouldn't be
    // specified.
    if (node_names.size() != node_mins.size() ||
        node_names.size() != node_maxs.size()) {
      return true;
    }
    for (int i = 0, e = node_names.size(); i != e; ++i) {
      quant_specs->input_ranges.push_back({node_mins[i], node_maxs[i]});
    }
    return false;
  }
  // For non-quantized inference types, any provided ranges are ignored.
  if (!node_mins.empty()) {
    llvm::dbgs() << "Ignored input_min_values.";
  }
  if (!node_maxs.empty()) {
    llvm::dbgs() << "Ignored input_max_values.";
  }
  return false;
}
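
// Illustrative note (not part of the original file): like the parser above,
// this helper returns true on error and false on success. A hypothetical
// direct call with explicit ranges:
//
//   QuantizationSpecs quant_specs;
//   bool failed = GetInputNodeQuantSpecs(
//       {"input0"}, {0.0}, {255.0}, tensorflow::DT_QUINT8, &quant_specs);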

}  // namespace quant
}  // namespace mlir