xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
23 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/ADT/StringSwitch.h"
27 #include "llvm/ADT/Twine.h"
28 #include "mlir/IR/Attributes.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
32 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
37 #include "tensorflow/lite/schema/schema_generated.h"
38 #include "tensorflow/lite/schema/schema_utils.h"
39 
40 namespace {
41 
42 using ::tensorflow::Status;
43 using ::tensorflow::errors::InvalidArgument;
44 using ::xla::StatusOr;
45 
GetPaddingAttr(TfLitePadding pad_params,mlir::Builder builder,mlir::Location loc)46 StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
47                                           mlir::Builder builder,
48                                           mlir::Location loc) {
49   auto padding = tflite::Padding::Padding_VALID;
50   if (pad_params == TfLitePadding::kTfLitePaddingSame) {
51     padding = tflite::Padding_SAME;
52   } else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
53     padding = tflite::Padding_VALID;
54   } else {
55     return InvalidArgument(
56         absl::StrCat("Invalid padding type", std::to_string(pad_params)));
57   }
58 
59   const char* option_name = tflite::EnumNamePadding(padding);
60   return builder.getStringAttr(option_name);
61 }
62 
63 }  // namespace
64 
GetMlirOpNameFromOpCode(const tflite::OperatorCodeT & op_code)65 std::string mlir::GetMlirOpNameFromOpCode(
66     const tflite::OperatorCodeT& op_code) {
67   auto builtin_code = tflite::GetBuiltinCode(&op_code);
68   if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
69     return std::string("tfl.custom");
70   }
71   if (builtin_code == tflite::BuiltinOperator_IF) {
72     return std::string("tf.If");
73   }
74   if (builtin_code == tflite::BuiltinOperator_WHILE) {
75     return std::string("tfl.while");
76   }
77 
78   llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
79   return llvm::Twine("tfl.", op_name.lower()).str();
80 }
81 
82 // TODO(jpienaar): This is a placeholder. This should be done in more efficient
83 // way when part of the translation of module.
ConvertTFL_AFAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)84 static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
85     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
86   return llvm::StringSwitch<tflite::ActivationFunctionType>(str)
87       .Case("NONE", tflite::ActivationFunctionType_NONE)
88       .Case("RELU", tflite::ActivationFunctionType_RELU)
89       .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
90       .Case("RELU6", tflite::ActivationFunctionType_RELU6)
91       .Case("TANH", tflite::ActivationFunctionType_TANH)
92       .Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
93 }
94 
ConvertDerivedTFLiteTypeAttrForOptionWriter(tflite::TensorType type,flatbuffers::FlatBufferBuilder * builder)95 static tflite::TensorType ConvertDerivedTFLiteTypeAttrForOptionWriter(
96     tflite::TensorType type, flatbuffers::FlatBufferBuilder* builder) {
97   if (type == tflite::TensorType_INT64) {
98     return tflite::TensorType_INT64;
99   } else if (type == tflite::TensorType_INT32) {
100     return tflite::TensorType_INT32;
101   }
102   llvm_unreachable("invalid type in conversion.");
103 }
104 
ConvertTFL_PaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)105 static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
106     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
107   return llvm::StringSwitch<tflite::Padding>(str)
108       .Case("SAME", tflite::Padding_SAME)
109       .Case("VALID", tflite::Padding_VALID);
110 }
111 
ConvertTFL_MirrorPaddingAttrForOptionWriter(mlir::TFL::MirrorPaddingType padding,flatbuffers::FlatBufferBuilder * builder)112 static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
113     mlir::TFL::MirrorPaddingType padding,
114     flatbuffers::FlatBufferBuilder* builder) {
115   switch (padding) {
116     case mlir::TFL::MirrorPaddingType::REFLECT:
117       return tflite::MirrorPadMode_REFLECT;
118     case mlir::TFL::MirrorPaddingType::SYMMETRIC:
119       return tflite::MirrorPadMode_SYMMETRIC;
120   }
121   llvm_unreachable("invalid mirror_pad_enum in conversion.");
122 }
123 
ConvertDerivedTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)124 static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
125     mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
126   return tflite::ConvertTypeToTensorType(type);
127 }
128 
129 // I32Attr already returns an int as required by flatbuffer builders.
ConvertI32AttrForOptionWriter(int i,flatbuffers::FlatBufferBuilder * builder)130 static int ConvertI32AttrForOptionWriter(
131     int i, flatbuffers::FlatBufferBuilder* builder) {
132   return i;
133 }
134 
135 // I64Attr already returns a int64_t as required by flatbuffer builders.
ConvertI64AttrForOptionWriter(int64_t i,flatbuffers::FlatBufferBuilder * builder)136 static int64_t ConvertI64AttrForOptionWriter(
137     int64_t i, flatbuffers::FlatBufferBuilder* builder) {
138   return i;
139 }
140 
// Option-writer conversion for positive I32 attributes. The value is returned
// unchanged; no range check is performed here. `builder` is unused.
static int ConvertPositiveI32AttrForOptionWriter(
    int i, flatbuffers::FlatBufferBuilder* builder) {
  // Same result as the plain I32 conversion, which is an identity mapping.
  return i;
}
145 
146 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,flatbuffers::FlatBufferBuilder * builder)147 ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,
148                                    flatbuffers::FlatBufferBuilder* builder) {
149   std::vector<int32_t> intVec;
150   intVec.reserve(attrArray.getValue().size());
151   for (auto attr : attrArray.getValue()) {
152     intVec.push_back(attr.cast<mlir::IntegerAttr>().getInt());
153   }
154   return builder->CreateVector(intVec);
155 }
156 
157 static flatbuffers::Offset<flatbuffers::Vector<float>>
ConvertF32ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,flatbuffers::FlatBufferBuilder * builder)158 ConvertF32ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,
159                                    flatbuffers::FlatBufferBuilder* builder) {
160   std::vector<float> floatVec;
161   floatVec.reserve(attrArray.getValue().size());
162   for (auto attr : attrArray.getValue()) {
163     floatVec.push_back(
164         attr.cast<mlir::FloatAttr>().getValue().convertToFloat());
165   }
166   return builder->CreateVector(floatVec);
167 }
168 
169 // F32Attr already returns a float as required by flatbuffer builders.
ConvertF32AttrForOptionWriter(llvm::APFloat f,flatbuffers::FlatBufferBuilder * builder)170 static float ConvertF32AttrForOptionWriter(
171     llvm::APFloat f, flatbuffers::FlatBufferBuilder* builder) {
172   return f.convertToFloat();
173 }
174 
175 // BoolAttr already returns a bool as required by flatbuffer builders.
ConvertBoolAttrForOptionWriter(bool b,flatbuffers::FlatBufferBuilder * builder)176 static bool ConvertBoolAttrForOptionWriter(
177     bool b, flatbuffers::FlatBufferBuilder* builder) {
178   return b;
179 }
180 
181 // Overloading of ConvertBoolAttrForOptionWriter which takes Optional<bool> as
182 // an input. If value is not specified, false is set for the attribute.
ConvertBoolAttrForOptionWriter(mlir::Optional<bool> b,flatbuffers::FlatBufferBuilder * builder)183 static bool ConvertBoolAttrForOptionWriter(
184     mlir::Optional<bool> b, flatbuffers::FlatBufferBuilder* builder) {
185   return b.has_value() ? b.getValue() : false;
186 }
187 
ConvertStrAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)188 static flatbuffers::Offset<flatbuffers::String> ConvertStrAttrForOptionWriter(
189     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
190   return builder->CreateString(str.str());
191 }
192 
ConvertTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)193 static tflite::TensorType ConvertTypeAttrForOptionWriter(
194     mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
195   return tflite::ConvertTypeToTensorType(type);
196 }
197 
198 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertDerivedShapeAttrForOptionWriter(llvm::ArrayRef<int64_t> r,flatbuffers::FlatBufferBuilder * builder)199 ConvertDerivedShapeAttrForOptionWriter(
200     llvm::ArrayRef<int64_t> r, flatbuffers::FlatBufferBuilder* builder) {
201   std::vector<int> intVec(r.begin(), r.end());
202   return builder->CreateVector(intVec);
203 }
204 
205 static tflite::FullyConnectedOptionsWeightsFormat
ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)206 ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(
207     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
208   return llvm::StringSwitch<tflite::FullyConnectedOptionsWeightsFormat>(str)
209       .Case("DEFAULT", tflite::FullyConnectedOptionsWeightsFormat_DEFAULT)
210       .Case("SHUFFLED4x16INT8",
211             tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
212 }
213 
ConvertTFL_LSTMKernelTypeAttrForOptionWriter(mlir::TFL::LSTMKernelType kernel_type,flatbuffers::FlatBufferBuilder * builder)214 static tflite::LSTMKernelType ConvertTFL_LSTMKernelTypeAttrForOptionWriter(
215     mlir::TFL::LSTMKernelType kernel_type,
216     flatbuffers::FlatBufferBuilder* builder) {
217   switch (kernel_type) {
218     case mlir::TFL::LSTMKernelType::FULL:
219       return tflite::LSTMKernelType_FULL;
220     case mlir::TFL::LSTMKernelType::BASIC:
221       return tflite::LSTMKernelType_BASIC;
222   }
223   llvm_unreachable("invalid lstm_kernel_type in conversion.");
224 }
225 
BuildBoolAttr(bool value,mlir::Builder builder)226 static mlir::Attribute BuildBoolAttr(bool value, mlir::Builder builder) {
227   return builder.getBoolAttr(value);
228 }
229 
BuildStrAttr(llvm::StringRef str,mlir::Builder builder)230 static mlir::Attribute BuildStrAttr(llvm::StringRef str,
231                                     mlir::Builder builder) {
232   return builder.getStringAttr(str);
233 }
234 
BuildF32Attr(float value,mlir::Builder builder)235 static mlir::Attribute BuildF32Attr(float value, mlir::Builder builder) {
236   return builder.getF32FloatAttr(value);
237 }
238 
BuildI32Attr(int32_t value,mlir::Builder builder)239 static mlir::Attribute BuildI32Attr(int32_t value, mlir::Builder builder) {
240   return builder.getI32IntegerAttr(value);
241 }
242 
BuildI64Attr(int64_t value,mlir::Builder builder)243 static mlir::Attribute BuildI64Attr(int64_t value, mlir::Builder builder) {
244   return builder.getI64IntegerAttr(value);
245 }
246 
BuildI64ArrayAttr(std::vector<int32_t> value,mlir::Builder builder)247 static mlir::Attribute BuildI64ArrayAttr(std::vector<int32_t> value,
248                                          mlir::Builder builder) {
249   std::vector<int64_t> typecast(value.begin(), value.end());
250   return builder.getI64ArrayAttr(typecast);
251 }
252 
BuildF32ArrayAttr(std::vector<float> value,mlir::Builder builder)253 static mlir::Attribute BuildF32ArrayAttr(std::vector<float> value,
254                                          mlir::Builder builder) {
255   std::vector<float> typecast(value.begin(), value.end());
256   return builder.getF32ArrayAttr(typecast);
257 }
258 
BuildPositiveI32Attr(int32_t value,mlir::Builder builder)259 static mlir::Attribute BuildPositiveI32Attr(int32_t value,
260                                             mlir::Builder builder) {
261   return builder.getI32IntegerAttr(value);
262 }
263 
BuildTypeAttr(tflite::TensorType value,mlir::Builder builder)264 static mlir::Attribute BuildTypeAttr(tflite::TensorType value,
265                                      mlir::Builder builder) {
266   return mlir::TypeAttr::get(ConvertElementType(value, builder));
267 }
268 
BuildTFL_AFAttr(tflite::ActivationFunctionType value,mlir::Builder builder)269 static mlir::Attribute BuildTFL_AFAttr(tflite::ActivationFunctionType value,
270                                        mlir::Builder builder) {
271   const char* option_name = tflite::EnumNameActivationFunctionType(value);
272   return builder.getStringAttr(option_name);
273 }
274 
BuildTFL_FullyConnectedOptionsWeightFormatAttr(tflite::FullyConnectedOptionsWeightsFormat value,mlir::Builder builder)275 static mlir::Attribute BuildTFL_FullyConnectedOptionsWeightFormatAttr(
276     tflite::FullyConnectedOptionsWeightsFormat value, mlir::Builder builder) {
277   const char* option_name =
278       tflite::EnumNameFullyConnectedOptionsWeightsFormat(value);
279   return builder.getStringAttr(option_name);
280 }
281 
BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,mlir::Builder builder)282 static mlir::Attribute BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,
283                                                    mlir::Builder builder) {
284   mlir::TFL::LSTMKernelType kernel_type;
285   switch (value) {
286     case tflite::LSTMKernelType_FULL:
287       kernel_type = mlir::TFL::LSTMKernelType::FULL;
288       break;
289     case tflite::LSTMKernelType_BASIC:
290       kernel_type = mlir::TFL::LSTMKernelType::BASIC;
291       break;
292   }
293   return mlir::TFL::LSTMKernelTypeAttr::get(builder.getContext(), kernel_type);
294 }
295 
BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,mlir::Builder builder)296 static mlir::Attribute BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,
297                                                   mlir::Builder builder) {
298   mlir::TFL::MirrorPaddingType padding;
299   switch (value) {
300     case tflite::MirrorPadMode_REFLECT:
301       padding = mlir::TFL::MirrorPaddingType::REFLECT;
302       break;
303     case tflite::MirrorPadMode_SYMMETRIC:
304     default:
305       padding = mlir::TFL::MirrorPaddingType::SYMMETRIC;
306       break;
307   }
308   return mlir::TFL::MirrorPaddingTypeAttr::get(builder.getContext(), padding);
309 }
310 
BuildTFL_PaddingAttr(tflite::Padding value,mlir::Builder builder)311 static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
312                                             mlir::Builder builder) {
313   const char* option_name = tflite::EnumNamePadding(value);
314   return builder.getStringAttr(option_name);
315 }
316 
CustomOptionsToAttributes(const std::string & custom_code,const std::vector<uint8_t> & custom_options,mlir::Builder builder,mlir::Location loc,llvm::SmallVectorImpl<mlir::NamedAttribute> * attributes)317 Status mlir::CustomOptionsToAttributes(
318     const std::string& custom_code, const std::vector<uint8_t>& custom_options,
319     mlir::Builder builder, mlir::Location loc,
320     llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
321   attributes->emplace_back(
322       builder.getNamedAttr("custom_code", builder.getStringAttr(custom_code)));
323   std::string content;
324   content.assign(reinterpret_cast<const char*>(custom_options.data()),
325                  custom_options.size());
326   attributes->emplace_back(builder.getNamedAttr(
327       "custom_option",
328       mlir::TFL::ConstBytesAttr::get(builder.getContext(), content)));
329 
330   return ::tensorflow::OkStatus();
331 }
332 
333 // Pull in FlatBuffer writers for TFLite generated using TableGen
334 #include "tensorflow/compiler/mlir/lite/operator_converters.inc"
335