1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
17
18 #include <algorithm>
19 #include <ios>
20 #include <sstream>
21 #include <string>
22 #include <vector>
23
24 #include "absl/strings/numbers.h"
25 #include "absl/strings/str_split.h"
26 #include "absl/strings/string_view.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "tensorflow/core/framework/types.pb.h"
30
31 // Is this dtype a quantization type from TensorFlow.
IsQuantizationType(tensorflow::DataType dtype)32 static bool IsQuantizationType(tensorflow::DataType dtype) {
33 switch (dtype) {
34 case tensorflow::DT_QINT8:
35 case tensorflow::DT_QUINT8:
36 case tensorflow::DT_QINT16:
37 case tensorflow::DT_QUINT16:
38 case tensorflow::DT_QINT32:
39 return true;
40 default:
41 return false;
42 }
43 }
44
45 namespace mlir {
46 namespace quant {
namespace {
// Parses a textual boolean from `bool_val` using std::boolalpha.
//
// Leading whitespace is skipped, and the value must start with the lowercase
// spelling "true" or "false". If no boolean can be read (for example an empty
// string, "1", or "True"), false is returned.
bool GetBooleanSpecs(const std::string& bool_val) {
  // Must be initialized: when the stream is empty the extraction sentry fails
  // before any conversion happens, and operator>> leaves `result` untouched.
  bool result = false;
  std::istringstream iss(bool_val);
  iss >> std::boolalpha >> result;
  return result;
}
}  // namespace
55
ParseCustomOpSpecs(absl::string_view node_names,const CustomOpUpdateOptions & update_option,CustomOpMap & custom_op_map)56 void ParseCustomOpSpecs(absl::string_view node_names,
57 const CustomOpUpdateOptions& update_option,
58 CustomOpMap& custom_op_map) {
59 if (node_names.empty()) return;
60
61 std::vector<std::string> custom_nodes = absl::StrSplit(node_names, ',');
62
63 for (auto& cur_node : custom_nodes) {
64 std::vector<std::string> node_infos = absl::StrSplit(cur_node, '=');
65 std::string node_name = node_infos[0];
66 auto node_specification = node_infos[1];
67 CustomOpInfo new_node_info;
68 switch (update_option) {
69 case CustomOpUpdateOptions::kINputIndices: {
70 std::vector<std::string> indices =
71 absl::StrSplit(node_specification, '-');
72 for (auto& cur_index : indices) {
73 custom_op_map[node_name].quantizable_input_indices.push_back(
74 std::stoi(cur_index));
75 }
76 break;
77 }
78 case CustomOpUpdateOptions::kWeightOnly:
79 custom_op_map[node_name].is_weight_only =
80 GetBooleanSpecs(node_specification);
81 break;
82 case CustomOpUpdateOptions::kNoSideEffect:
83 custom_op_map[node_name].no_side_effect =
84 GetBooleanSpecs(node_specification);
85 break;
86 }
87 }
88 }
89
ParseInputNodeQuantSpecs(absl::string_view node_names,absl::string_view min_values,absl::string_view max_values,absl::string_view inference_type,QuantizationSpecs * quant_specs)90 bool ParseInputNodeQuantSpecs(absl::string_view node_names,
91 absl::string_view min_values,
92 absl::string_view max_values,
93 absl::string_view inference_type,
94 QuantizationSpecs* quant_specs) {
95 std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
96 std::vector<llvm::Optional<double>> node_mins;
97 if (!min_values.empty()) {
98 std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
99 for (int i = 0, e = node_mins_str.size(); i < e; i++) {
100 double value;
101 if (!absl::SimpleAtod(node_mins_str[i], &value)) {
102 return true;
103 }
104 node_mins.push_back(value);
105 }
106 }
107
108 std::vector<llvm::Optional<double>> node_maxs;
109 if (!max_values.empty()) {
110 std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
111 for (int i = 0, e = node_maxs_str.size(); i < e; i++) {
112 double value;
113 if (!absl::SimpleAtod(node_maxs_str[i], &value)) {
114 llvm::errs() << "Unexpected mins: " << node_maxs_str[i] << "\n";
115 return true;
116 }
117 node_maxs.push_back(value);
118 }
119 }
120
121 tensorflow::DataType final_type = tensorflow::DT_FLOAT;
122 if (!inference_type.empty() &&
123 !DataType_Parse(std::string(inference_type), &final_type)) {
124 return true;
125 }
126 return GetInputNodeQuantSpecs(input_nodes, node_mins, node_maxs, final_type,
127 quant_specs);
128 }
129
// Stores `inference_type` in `quant_specs`. For quantized inference types it
// also appends one (min, max) pair per input node to
// `quant_specs->input_ranges`.
//
// If either `node_mins` or `node_maxs` is empty, only the inference type is
// recorded. Otherwise the behavior depends on `inference_type`:
//   - Quantized types: both lists must have exactly one entry per element of
//     `node_names`.
//   - Other types: the ranges are ignored, and a note is written to
//     llvm::dbgs().
//
// Returns true on failure (a min/max count mismatch), false otherwise.
bool GetInputNodeQuantSpecs(
    const std::vector<std::string>& node_names,
    const std::vector<llvm::Optional<double>>& node_mins,
    const std::vector<llvm::Optional<double>>& node_maxs,
    tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
  quant_specs->inference_type = inference_type;

  // If min/max are not specified, there is nothing more to record.
  if (node_mins.empty() || node_maxs.empty()) return false;

  if (!IsQuantizationType(inference_type)) {
    // Both lists are non-empty here, so both are being dropped.
    llvm::dbgs() << "Ignored input_min_values.\n";
    llvm::dbgs() << "Ignored input_max_values.\n";
    return false;
  }

  // min/max must have the same size as the inputs.
  if (node_names.size() != node_mins.size() ||
      node_names.size() != node_maxs.size()) {
    llvm::errs() << "Expected " << node_names.size()
                 << " input min/max values, got " << node_mins.size()
                 << " mins and " << node_maxs.size() << " maxs\n";
    return true;
  }
  for (size_t i = 0; i < node_names.size(); ++i) {
    quant_specs->input_ranges.push_back({node_mins[i], node_maxs[i]});
  }
  return false;
}
160
161 } // namespace quant
162 } // namespace mlir
163