1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
16 #define TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
17
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>

#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tflite/operator.h"
#include "tensorflow/lite/util.h"
23
24 namespace toco {
25
26 namespace tflite {
27
// Buffer type requested for weight quantization when exporting.
// NONE means no weight quantization is requested (the ExportParams default);
// the legacy boolean `quantize_weights == true` maps to INT8.
// NOTE(review): FLOAT16 presumably requests float16 weight buffers — confirm
// in the exporter implementation.
enum class QuantizedBufferType : int {
  NONE = 0,
  INT8 = 1,
  FLOAT16 = 2,
};
29
30 // The parameters for exporting a TFLite model.
31 struct ExportParams {
32 bool allow_custom_ops = false;
33 bool allow_dynamic_tensors = true;
34 bool enable_select_tf_ops = false;
35 QuantizedBufferType quantize_weights = QuantizedBufferType::NONE;
36 // Whether to use per-tensor (false) or per-channel (true) for hybrid quant.
37 bool disable_per_channel = false;
38 };
39
40 // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
41 // result in the given string.
42 tensorflow::Status Export(const Model& model, std::string* output_file_contents,
43 const ExportParams& params);
44
45 // Export API with custom TFLite operator mapping.
46 tensorflow::Status Export(
47 const Model& model, std::string* output_file_contents,
48 const ExportParams& params,
49 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
50
51 // This is for backward-compatibility.
52 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,std::string * output_file_contents)53 inline void Export(const Model& model, bool allow_custom_ops,
54 bool quantize_weights, std::string* output_file_contents) {
55 ExportParams params;
56 params.allow_custom_ops = allow_custom_ops;
57 params.quantize_weights =
58 quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
59 auto status = Export(model, output_file_contents, params);
60 if (!status.ok()) LOG(QFATAL) << status.error_message();
61 }
62
63 // This is for backward-compatibility.
64 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,std::string * output_file_contents,const std::map<OperatorType,std::unique_ptr<BaseOperator>> & ops_by_type)65 inline void Export(
66 const Model& model, bool allow_custom_ops, bool quantize_weights,
67 std::string* output_file_contents,
68 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
69 ExportParams params;
70 params.allow_custom_ops = allow_custom_ops;
71 params.quantize_weights =
72 quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
73 auto status = Export(model, output_file_contents, params, ops_by_type);
74 if (!status.ok()) LOG(QFATAL) << status.error_message();
75 }
76
77 // This is for backward-compatibility.
78 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,std::string * output_file_contents)79 inline void Export(const Model& model, std::string* output_file_contents) {
80 ExportParams params;
81 params.allow_custom_ops = true;
82 auto status = Export(model, output_file_contents, params);
83 if (!status.ok()) LOG(QFATAL) << status.error_message();
84 }
85
86 namespace details {
87
// Tensor name -> the final position assigned to that tensor in the exported
// TF Lite buffer.
using TensorsMap = std::unordered_map<std::string, int>;
90
91 // A key to identify an operator.
92 // Only when `type` is `kUnsupported`, `custom_code` is filled to
93 // identify which operation is used.
94 class OperatorKey {
95 public:
OperatorKey()96 OperatorKey() {}
97
98 // Construct OperatorKey by Toco op.
99 OperatorKey(
100 const ::toco::OperatorSignature& op_signature,
101 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
102 bool enable_select_tf_ops);
103
104 // Construct OperatorKey by type, custom code and version.
105 // Note that this construct doesn't set the additional information including
106 // `is_custom_op`, `is_flex_op`, `is_unsupported_flex_op`.
OperatorKey(::tflite::BuiltinOperator type,const std::string & custom_code,int version)107 OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code,
108 int version)
109 : type_(type), custom_code_(custom_code), version_(version) {}
110
111 // Only `type`, `custom_code` and `version` is used to compute hash and
112 // identity.
type()113 ::tflite::BuiltinOperator type() const { return type_; }
custom_code()114 const std::string& custom_code() const { return custom_code_; }
version()115 int version() const { return version_; }
116
117 // The attributes below are not used to compute hash and identity.
118 //
119 // Return true if the op is a custom op. Note it will return false for Flex
120 // ops.
is_custom_op()121 bool is_custom_op() const { return is_custom_op_; }
122 // Return true if the op is a Flex op.
is_flex_op()123 bool is_flex_op() const { return is_flex_op_; }
124 // Return true if the op is a Flex op but it's knwon that the op is not
125 // supported by Flex runtime.
is_unsupported_flex_op()126 bool is_unsupported_flex_op() const { return is_unsupported_flex_op_; }
127 // Return the original TensorFlow op name for a Flex op.
flex_tensorflow_op()128 const std::string& flex_tensorflow_op() const { return flex_tensorflow_op_; }
129
130 bool operator<(const OperatorKey& other) const {
131 if (type_ < other.type_)
132 return true;
133 else if (type_ > other.type_)
134 return false;
135 else if (custom_code_ < other.custom_code_)
136 return true;
137 else if (custom_code_ > other.custom_code_)
138 return false;
139 else
140 return version_ < other.version_;
141 }
142
143 bool operator==(const OperatorKey& other) const {
144 return type_ == other.type_ && custom_code_ == other.custom_code_ &&
145 version_ == other.version_;
146 }
147
148 struct Hash {
operatorHash149 size_t operator()(const OperatorKey& key) const {
150 return ::tflite::CombineHashes(
151 {std::hash<size_t>()(static_cast<size_t>(key.type())),
152 std::hash<std::string>()(key.custom_code()),
153 std::hash<int>()(key.version())});
154 }
155 };
156
157 private:
158 ::tflite::BuiltinOperator type_ = ::tflite::BuiltinOperator_CUSTOM;
159 std::string custom_code_;
160 int version_ = 1;
161
162 bool is_custom_op_ = false;
163 bool is_flex_op_ = false;
164 bool is_unsupported_flex_op_ = false;
165 // The original TensorFlow op name for the flex op. Filled only when
166 // `is_flex_op` is true.
167 std::string flex_tensorflow_op_;
168 };
169
170 // A map from OperatorKey to its final position in the TF Lite buffer.
171 using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
172
173 void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
174 void LoadOperatorsMap(
175 const Model& model, OperatorsMap* operators_map,
176 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
177 bool enable_select_tf_ops);
178
179 } // namespace details
180 } // namespace tflite
181 } // namespace toco
182
183 #endif // TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
184