xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/tflite/export.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
16 #define TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
17 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tflite/operator.h"
#include "tensorflow/lite/util.h"
23 
24 namespace toco {
25 
26 namespace tflite {
27 
// Weight quantization mode applied when exporting a TFLite model.
// NONE leaves weights unquantized.
enum class QuantizedBufferType { NONE, INT8, FLOAT16 };

// The parameters for exporting a TFLite model.
struct ExportParams {
  bool allow_custom_ops = false;
  bool allow_dynamic_tensors = true;
  bool enable_select_tf_ops = false;
  // Which weight quantization (if any) to apply; defaults to none.
  QuantizedBufferType quantize_weights = QuantizedBufferType::NONE;
  // Granularity of hybrid weight quantization. By default (false) weights are
  // quantized per-channel; setting this to true disables per-channel
  // quantization so that per-tensor quantization is used instead.
  bool disable_per_channel = false;
};
39 
40 // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
41 // result in the given string.
42 tensorflow::Status Export(const Model& model, std::string* output_file_contents,
43                           const ExportParams& params);
44 
45 // Export API with custom TFLite operator mapping.
46 tensorflow::Status Export(
47     const Model& model, std::string* output_file_contents,
48     const ExportParams& params,
49     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
50 
51 // This is for backward-compatibility.
52 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,std::string * output_file_contents)53 inline void Export(const Model& model, bool allow_custom_ops,
54                    bool quantize_weights, std::string* output_file_contents) {
55   ExportParams params;
56   params.allow_custom_ops = allow_custom_ops;
57   params.quantize_weights =
58       quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
59   auto status = Export(model, output_file_contents, params);
60   if (!status.ok()) LOG(QFATAL) << status.error_message();
61 }
62 
63 // This is for backward-compatibility.
64 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,std::string * output_file_contents,const std::map<OperatorType,std::unique_ptr<BaseOperator>> & ops_by_type)65 inline void Export(
66     const Model& model, bool allow_custom_ops, bool quantize_weights,
67     std::string* output_file_contents,
68     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
69   ExportParams params;
70   params.allow_custom_ops = allow_custom_ops;
71   params.quantize_weights =
72       quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
73   auto status = Export(model, output_file_contents, params, ops_by_type);
74   if (!status.ok()) LOG(QFATAL) << status.error_message();
75 }
76 
77 // This is for backward-compatibility.
78 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,std::string * output_file_contents)79 inline void Export(const Model& model, std::string* output_file_contents) {
80   ExportParams params;
81   params.allow_custom_ops = true;
82   auto status = Export(model, output_file_contents, params);
83   if (!status.ok()) LOG(QFATAL) << status.error_message();
84 }
85 
86 namespace details {
87 
// Tensor name -> final position (index) of that tensor in the TF Lite buffer.
using TensorsMap =
    std::unordered_map<std::string /* tensor name */, int /* position */>;
90 
91 // A key to identify an operator.
92 // Only when `type` is `kUnsupported`, `custom_code` is filled to
93 // identify which operation is used.
94 class OperatorKey {
95  public:
OperatorKey()96   OperatorKey() {}
97 
98   // Construct OperatorKey by Toco op.
99   OperatorKey(
100       const ::toco::OperatorSignature& op_signature,
101       const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
102       bool enable_select_tf_ops);
103 
104   // Construct OperatorKey by type, custom code and version.
105   // Note that this construct doesn't set the additional information including
106   // `is_custom_op`, `is_flex_op`, `is_unsupported_flex_op`.
OperatorKey(::tflite::BuiltinOperator type,const std::string & custom_code,int version)107   OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code,
108               int version)
109       : type_(type), custom_code_(custom_code), version_(version) {}
110 
111   // Only `type`, `custom_code` and `version` is used to compute hash and
112   // identity.
type()113   ::tflite::BuiltinOperator type() const { return type_; }
custom_code()114   const std::string& custom_code() const { return custom_code_; }
version()115   int version() const { return version_; }
116 
117   // The attributes below are not used to compute hash and identity.
118   //
119   // Return true if the op is a custom op. Note it will return false for Flex
120   // ops.
is_custom_op()121   bool is_custom_op() const { return is_custom_op_; }
122   // Return true if the op is a Flex op.
is_flex_op()123   bool is_flex_op() const { return is_flex_op_; }
124   // Return true if the op is a Flex op but it's knwon that the op is not
125   // supported by Flex runtime.
is_unsupported_flex_op()126   bool is_unsupported_flex_op() const { return is_unsupported_flex_op_; }
127   // Return the original TensorFlow op name for a Flex op.
flex_tensorflow_op()128   const std::string& flex_tensorflow_op() const { return flex_tensorflow_op_; }
129 
130   bool operator<(const OperatorKey& other) const {
131     if (type_ < other.type_)
132       return true;
133     else if (type_ > other.type_)
134       return false;
135     else if (custom_code_ < other.custom_code_)
136       return true;
137     else if (custom_code_ > other.custom_code_)
138       return false;
139     else
140       return version_ < other.version_;
141   }
142 
143   bool operator==(const OperatorKey& other) const {
144     return type_ == other.type_ && custom_code_ == other.custom_code_ &&
145            version_ == other.version_;
146   }
147 
148   struct Hash {
operatorHash149     size_t operator()(const OperatorKey& key) const {
150       return ::tflite::CombineHashes(
151           {std::hash<size_t>()(static_cast<size_t>(key.type())),
152            std::hash<std::string>()(key.custom_code()),
153            std::hash<int>()(key.version())});
154     }
155   };
156 
157  private:
158   ::tflite::BuiltinOperator type_ = ::tflite::BuiltinOperator_CUSTOM;
159   std::string custom_code_;
160   int version_ = 1;
161 
162   bool is_custom_op_ = false;
163   bool is_flex_op_ = false;
164   bool is_unsupported_flex_op_ = false;
165   // The original TensorFlow op name for the flex op. Filled only when
166   // `is_flex_op` is true.
167   std::string flex_tensorflow_op_;
168 };
169 
170 // A map from OperatorKey to its final position in the TF Lite buffer.
171 using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
172 
173 void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
174 void LoadOperatorsMap(
175     const Model& model, OperatorsMap* operators_map,
176     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
177     bool enable_select_tf_ops);
178 
179 }  // namespace details
180 }  // namespace tflite
181 }  // namespace toco
182 
183 #endif  // TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
184