1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
23 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/ADT/StringSwitch.h"
27 #include "llvm/ADT/Twine.h"
28 #include "mlir/IR/Attributes.h" // from @llvm-project
29 #include "mlir/IR/Builders.h" // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
31 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
32 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
37 #include "tensorflow/lite/schema/schema_generated.h"
38 #include "tensorflow/lite/schema/schema_utils.h"
39
40 namespace {
41
42 using ::tensorflow::Status;
43 using ::tensorflow::errors::InvalidArgument;
44 using ::xla::StatusOr;
45
GetPaddingAttr(TfLitePadding pad_params,mlir::Builder builder,mlir::Location loc)46 StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
47 mlir::Builder builder,
48 mlir::Location loc) {
49 auto padding = tflite::Padding::Padding_VALID;
50 if (pad_params == TfLitePadding::kTfLitePaddingSame) {
51 padding = tflite::Padding_SAME;
52 } else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
53 padding = tflite::Padding_VALID;
54 } else {
55 return InvalidArgument(
56 absl::StrCat("Invalid padding type", std::to_string(pad_params)));
57 }
58
59 const char* option_name = tflite::EnumNamePadding(padding);
60 return builder.getStringAttr(option_name);
61 }
62
63 } // namespace
64
GetMlirOpNameFromOpCode(const tflite::OperatorCodeT & op_code)65 std::string mlir::GetMlirOpNameFromOpCode(
66 const tflite::OperatorCodeT& op_code) {
67 auto builtin_code = tflite::GetBuiltinCode(&op_code);
68 if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
69 return std::string("tfl.custom");
70 }
71 if (builtin_code == tflite::BuiltinOperator_IF) {
72 return std::string("tf.If");
73 }
74 if (builtin_code == tflite::BuiltinOperator_WHILE) {
75 return std::string("tfl.while");
76 }
77
78 llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
79 return llvm::Twine("tfl.", op_name.lower()).str();
80 }
81
82 // TODO(jpienaar): This is a placeholder. This should be done in more efficient
83 // way when part of the translation of module.
ConvertTFL_AFAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)84 static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
85 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
86 return llvm::StringSwitch<tflite::ActivationFunctionType>(str)
87 .Case("NONE", tflite::ActivationFunctionType_NONE)
88 .Case("RELU", tflite::ActivationFunctionType_RELU)
89 .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
90 .Case("RELU6", tflite::ActivationFunctionType_RELU6)
91 .Case("TANH", tflite::ActivationFunctionType_TANH)
92 .Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
93 }
94
ConvertDerivedTFLiteTypeAttrForOptionWriter(tflite::TensorType type,flatbuffers::FlatBufferBuilder * builder)95 static tflite::TensorType ConvertDerivedTFLiteTypeAttrForOptionWriter(
96 tflite::TensorType type, flatbuffers::FlatBufferBuilder* builder) {
97 if (type == tflite::TensorType_INT64) {
98 return tflite::TensorType_INT64;
99 } else if (type == tflite::TensorType_INT32) {
100 return tflite::TensorType_INT32;
101 }
102 llvm_unreachable("invalid type in conversion.");
103 }
104
ConvertTFL_PaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)105 static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
106 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
107 return llvm::StringSwitch<tflite::Padding>(str)
108 .Case("SAME", tflite::Padding_SAME)
109 .Case("VALID", tflite::Padding_VALID);
110 }
111
ConvertTFL_MirrorPaddingAttrForOptionWriter(mlir::TFL::MirrorPaddingType padding,flatbuffers::FlatBufferBuilder * builder)112 static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
113 mlir::TFL::MirrorPaddingType padding,
114 flatbuffers::FlatBufferBuilder* builder) {
115 switch (padding) {
116 case mlir::TFL::MirrorPaddingType::REFLECT:
117 return tflite::MirrorPadMode_REFLECT;
118 case mlir::TFL::MirrorPaddingType::SYMMETRIC:
119 return tflite::MirrorPadMode_SYMMETRIC;
120 }
121 llvm_unreachable("invalid mirror_pad_enum in conversion.");
122 }
123
ConvertDerivedTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)124 static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
125 mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
126 return tflite::ConvertTypeToTensorType(type);
127 }
128
129 // I32Attr already returns an int as required by flatbuffer builders.
ConvertI32AttrForOptionWriter(int i,flatbuffers::FlatBufferBuilder * builder)130 static int ConvertI32AttrForOptionWriter(
131 int i, flatbuffers::FlatBufferBuilder* builder) {
132 return i;
133 }
134
135 // I64Attr already returns a int64_t as required by flatbuffer builders.
ConvertI64AttrForOptionWriter(int64_t i,flatbuffers::FlatBufferBuilder * builder)136 static int64_t ConvertI64AttrForOptionWriter(
137 int64_t i, flatbuffers::FlatBufferBuilder* builder) {
138 return i;
139 }
140
// Positive-I32 attributes are written exactly like ordinary I32 attributes:
// the int value is forwarded as-is (no positivity check happens here).
static int ConvertPositiveI32AttrForOptionWriter(
    int i, flatbuffers::FlatBufferBuilder* /*builder*/) {
  return i;
}
145
146 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,flatbuffers::FlatBufferBuilder * builder)147 ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,
148 flatbuffers::FlatBufferBuilder* builder) {
149 std::vector<int32_t> intVec;
150 intVec.reserve(attrArray.getValue().size());
151 for (auto attr : attrArray.getValue()) {
152 intVec.push_back(attr.cast<mlir::IntegerAttr>().getInt());
153 }
154 return builder->CreateVector(intVec);
155 }
156
157 static flatbuffers::Offset<flatbuffers::Vector<float>>
ConvertF32ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,flatbuffers::FlatBufferBuilder * builder)158 ConvertF32ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,
159 flatbuffers::FlatBufferBuilder* builder) {
160 std::vector<float> floatVec;
161 floatVec.reserve(attrArray.getValue().size());
162 for (auto attr : attrArray.getValue()) {
163 floatVec.push_back(
164 attr.cast<mlir::FloatAttr>().getValue().convertToFloat());
165 }
166 return builder->CreateVector(floatVec);
167 }
168
169 // F32Attr already returns a float as required by flatbuffer builders.
ConvertF32AttrForOptionWriter(llvm::APFloat f,flatbuffers::FlatBufferBuilder * builder)170 static float ConvertF32AttrForOptionWriter(
171 llvm::APFloat f, flatbuffers::FlatBufferBuilder* builder) {
172 return f.convertToFloat();
173 }
174
175 // BoolAttr already returns a bool as required by flatbuffer builders.
ConvertBoolAttrForOptionWriter(bool b,flatbuffers::FlatBufferBuilder * builder)176 static bool ConvertBoolAttrForOptionWriter(
177 bool b, flatbuffers::FlatBufferBuilder* builder) {
178 return b;
179 }
180
181 // Overloading of ConvertBoolAttrForOptionWriter which takes Optional<bool> as
182 // an input. If value is not specified, false is set for the attribute.
ConvertBoolAttrForOptionWriter(mlir::Optional<bool> b,flatbuffers::FlatBufferBuilder * builder)183 static bool ConvertBoolAttrForOptionWriter(
184 mlir::Optional<bool> b, flatbuffers::FlatBufferBuilder* builder) {
185 return b.has_value() ? b.getValue() : false;
186 }
187
ConvertStrAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)188 static flatbuffers::Offset<flatbuffers::String> ConvertStrAttrForOptionWriter(
189 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
190 return builder->CreateString(str.str());
191 }
192
ConvertTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)193 static tflite::TensorType ConvertTypeAttrForOptionWriter(
194 mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
195 return tflite::ConvertTypeToTensorType(type);
196 }
197
198 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertDerivedShapeAttrForOptionWriter(llvm::ArrayRef<int64_t> r,flatbuffers::FlatBufferBuilder * builder)199 ConvertDerivedShapeAttrForOptionWriter(
200 llvm::ArrayRef<int64_t> r, flatbuffers::FlatBufferBuilder* builder) {
201 std::vector<int> intVec(r.begin(), r.end());
202 return builder->CreateVector(intVec);
203 }
204
205 static tflite::FullyConnectedOptionsWeightsFormat
ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)206 ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(
207 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
208 return llvm::StringSwitch<tflite::FullyConnectedOptionsWeightsFormat>(str)
209 .Case("DEFAULT", tflite::FullyConnectedOptionsWeightsFormat_DEFAULT)
210 .Case("SHUFFLED4x16INT8",
211 tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
212 }
213
ConvertTFL_LSTMKernelTypeAttrForOptionWriter(mlir::TFL::LSTMKernelType kernel_type,flatbuffers::FlatBufferBuilder * builder)214 static tflite::LSTMKernelType ConvertTFL_LSTMKernelTypeAttrForOptionWriter(
215 mlir::TFL::LSTMKernelType kernel_type,
216 flatbuffers::FlatBufferBuilder* builder) {
217 switch (kernel_type) {
218 case mlir::TFL::LSTMKernelType::FULL:
219 return tflite::LSTMKernelType_FULL;
220 case mlir::TFL::LSTMKernelType::BASIC:
221 return tflite::LSTMKernelType_BASIC;
222 }
223 llvm_unreachable("invalid lstm_kernel_type in conversion.");
224 }
225
BuildBoolAttr(bool value,mlir::Builder builder)226 static mlir::Attribute BuildBoolAttr(bool value, mlir::Builder builder) {
227 return builder.getBoolAttr(value);
228 }
229
BuildStrAttr(llvm::StringRef str,mlir::Builder builder)230 static mlir::Attribute BuildStrAttr(llvm::StringRef str,
231 mlir::Builder builder) {
232 return builder.getStringAttr(str);
233 }
234
BuildF32Attr(float value,mlir::Builder builder)235 static mlir::Attribute BuildF32Attr(float value, mlir::Builder builder) {
236 return builder.getF32FloatAttr(value);
237 }
238
BuildI32Attr(int32_t value,mlir::Builder builder)239 static mlir::Attribute BuildI32Attr(int32_t value, mlir::Builder builder) {
240 return builder.getI32IntegerAttr(value);
241 }
242
BuildI64Attr(int64_t value,mlir::Builder builder)243 static mlir::Attribute BuildI64Attr(int64_t value, mlir::Builder builder) {
244 return builder.getI64IntegerAttr(value);
245 }
246
BuildI64ArrayAttr(std::vector<int32_t> value,mlir::Builder builder)247 static mlir::Attribute BuildI64ArrayAttr(std::vector<int32_t> value,
248 mlir::Builder builder) {
249 std::vector<int64_t> typecast(value.begin(), value.end());
250 return builder.getI64ArrayAttr(typecast);
251 }
252
BuildF32ArrayAttr(std::vector<float> value,mlir::Builder builder)253 static mlir::Attribute BuildF32ArrayAttr(std::vector<float> value,
254 mlir::Builder builder) {
255 std::vector<float> typecast(value.begin(), value.end());
256 return builder.getF32ArrayAttr(typecast);
257 }
258
BuildPositiveI32Attr(int32_t value,mlir::Builder builder)259 static mlir::Attribute BuildPositiveI32Attr(int32_t value,
260 mlir::Builder builder) {
261 return builder.getI32IntegerAttr(value);
262 }
263
BuildTypeAttr(tflite::TensorType value,mlir::Builder builder)264 static mlir::Attribute BuildTypeAttr(tflite::TensorType value,
265 mlir::Builder builder) {
266 return mlir::TypeAttr::get(ConvertElementType(value, builder));
267 }
268
BuildTFL_AFAttr(tflite::ActivationFunctionType value,mlir::Builder builder)269 static mlir::Attribute BuildTFL_AFAttr(tflite::ActivationFunctionType value,
270 mlir::Builder builder) {
271 const char* option_name = tflite::EnumNameActivationFunctionType(value);
272 return builder.getStringAttr(option_name);
273 }
274
BuildTFL_FullyConnectedOptionsWeightFormatAttr(tflite::FullyConnectedOptionsWeightsFormat value,mlir::Builder builder)275 static mlir::Attribute BuildTFL_FullyConnectedOptionsWeightFormatAttr(
276 tflite::FullyConnectedOptionsWeightsFormat value, mlir::Builder builder) {
277 const char* option_name =
278 tflite::EnumNameFullyConnectedOptionsWeightsFormat(value);
279 return builder.getStringAttr(option_name);
280 }
281
BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,mlir::Builder builder)282 static mlir::Attribute BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,
283 mlir::Builder builder) {
284 mlir::TFL::LSTMKernelType kernel_type;
285 switch (value) {
286 case tflite::LSTMKernelType_FULL:
287 kernel_type = mlir::TFL::LSTMKernelType::FULL;
288 break;
289 case tflite::LSTMKernelType_BASIC:
290 kernel_type = mlir::TFL::LSTMKernelType::BASIC;
291 break;
292 }
293 return mlir::TFL::LSTMKernelTypeAttr::get(builder.getContext(), kernel_type);
294 }
295
BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,mlir::Builder builder)296 static mlir::Attribute BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,
297 mlir::Builder builder) {
298 mlir::TFL::MirrorPaddingType padding;
299 switch (value) {
300 case tflite::MirrorPadMode_REFLECT:
301 padding = mlir::TFL::MirrorPaddingType::REFLECT;
302 break;
303 case tflite::MirrorPadMode_SYMMETRIC:
304 default:
305 padding = mlir::TFL::MirrorPaddingType::SYMMETRIC;
306 break;
307 }
308 return mlir::TFL::MirrorPaddingTypeAttr::get(builder.getContext(), padding);
309 }
310
BuildTFL_PaddingAttr(tflite::Padding value,mlir::Builder builder)311 static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
312 mlir::Builder builder) {
313 const char* option_name = tflite::EnumNamePadding(value);
314 return builder.getStringAttr(option_name);
315 }
316
CustomOptionsToAttributes(const std::string & custom_code,const std::vector<uint8_t> & custom_options,mlir::Builder builder,mlir::Location loc,llvm::SmallVectorImpl<mlir::NamedAttribute> * attributes)317 Status mlir::CustomOptionsToAttributes(
318 const std::string& custom_code, const std::vector<uint8_t>& custom_options,
319 mlir::Builder builder, mlir::Location loc,
320 llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
321 attributes->emplace_back(
322 builder.getNamedAttr("custom_code", builder.getStringAttr(custom_code)));
323 std::string content;
324 content.assign(reinterpret_cast<const char*>(custom_options.data()),
325 custom_options.size());
326 attributes->emplace_back(builder.getNamedAttr(
327 "custom_option",
328 mlir::TFL::ConstBytesAttr::get(builder.getContext(), content)));
329
330 return ::tensorflow::OkStatus();
331 }
332
333 // Pull in FlatBuffer writers for TFLite generated using TableGen
334 #include "tensorflow/compiler/mlir/lite/operator_converters.inc"
335