1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <iostream>
17 #include <memory>
18
19 #include "absl/strings/string_view.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/InitLLVM.h"
22 #include "llvm/Support/MemoryBuffer.h"
23 #include "llvm/Support/PrettyStackTrace.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "tensorflow/lite/model.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27 #include "tensorflow/lite/schema/schema_utils.h"
28
29 using llvm::Optional;
30 using llvm::cl::opt;
31
32 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s.mlir -o - \
33 // RUN: | %p/importer_test_min_max - \
34 // RUN: | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - \
35 // RUN: | FileCheck %s
36
37 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s.mlir -o - \
38 // RUN: | %p/importer_test_min_max - \
39 // RUN: | flatbuffer_to_string - \
40 // RUN: | FileCheck --check-prefix=FB %s
41
// Tests that a TFLite model whose tensors carry min/max statistics is
// imported correctly.
44
45 // NOLINTNEXTLINE
46 static opt<std::string> inputFileName(llvm::cl::Positional,
47 llvm::cl::desc("<input file>"),
48 llvm::cl::init("-"));
49
50 namespace mlir {
51 namespace {
InjectStatsToFullyConnected(llvm::StringRef buffer)52 Optional<std::unique_ptr<tflite::ModelT>> InjectStatsToFullyConnected(
53 llvm::StringRef buffer) {
54 auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
55 buffer.data(), buffer.size());
56 if (nullptr == model_ptr) {
57 return llvm::None;
58 }
59 std::unique_ptr<tflite::ModelT> model(model_ptr->GetModel()->UnPack());
60
61 // FB-LABEL: name: "arg0",
62 // FB-NEXT: quantization: {
63 // FB-NEXT: min: [ -1.0 ],
64 // FB-NEXT: max: [ 1.0 ]
65 // FB-NEXT: }
66
67 // FB-LABEL: name: "arg1",
68 // FB-NEXT: quantization: {
69 // FB-EMPTY:
70 // FB-NEXT: }
71
72 // FB-LABEL: name: "tfl.fully_connected",
73 // FB-NEXT: quantization: {
74 // FB-NEXT: min: [ -0.0, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0,
75 // FB-SAME: -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0, -16.0,
76 // FB-SAME: -17.0, -18.0, -19.0, -20.0, -21.0, -22.0, -23.0, -24.0, -25.0,
77 // FB-SAME: -26.0, -27.0, -28.0, -29.0, -30.0, -31.0, -32.0, -33.0, -34.0,
78 // FB-SAME: -35.0, -36.0, -37.0, -38.0, -39.0 ],
79 // FB-NEXT: max: [ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
80 // FB-SAME: 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0,
81 // FB-SAME: 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
82 // FB-SAME: 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0 ],
83 // FB-NEXT: quantized_dimension: 1
84 // FB-NEXT: }
85
86 // FB-LABEL: name: "tfl.fully_connected:1",
87 // FB-NEXT: quantization: {
88 // FB-EMPTY:
89 // FB-NEXT: }
90
91 // FB-LABEL: operators: [ {
92 // FB-NEXT: inputs: [ 0, 1, 2 ],
93 // FB-NEXT: outputs: [ 3, 4 ],
94 // FB-NEXT: builtin_options_type: FullyConnectedOptions,
95 // FB-NEXT: builtin_options: {
96 // FB-EMPTY:
97 // FB-NEXT: }
98 // FB-NEXT: } ],
99
100 // CHECK-LABEL: func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>)
101 // CHECK-SAME: -> tensor<40x40xf32>
102 // CHECK: %[[stat:.*]] = "quantfork.stats"(%arg0) {layerStats = dense<
103 // CHECK-SAME: [-1.000000e+00, 1.000000e+00]> : tensor<2xf32>}
104 // CHECK-SAME: : (tensor<40x37xf32>) -> tensor<40x37xf32>
105 // CHECK-NEXT: %[[cst:.*]] = "tfl.pseudo_const"() {value = dense<
106 // CHECK-SAME: 1.000000e+00> : tensor<40xf32>} : () -> tensor<40xf32>
107 // CHECK-NEXT: %[[fc:.*]]:2 = "tfl.fully_connected"(%[[stat]], %arg1,
108 // CHECK-NEXT: %[[stat1:.*]] = "quantfork.stats"(%[[fc]]#0)
109 // CHECK-SAME: {axis = 1 : i64,
110 // CHECK-SAME: axisStats = dense<{{\[}}[-0.000000e+00, 0.000000e+00],
111 // CHECK-SAME: [-1.000000e+00, 1.000000e+00],
112 // CHECK-SAME: [-2.000000e+00, 2.000000e+00]
113 // CHECK-NEXT: return %[[stat1]] : tensor<40x40xf32>
114 // CHECK-NEXT: }
115
116 // Find the tensors and inject the min and max to the input and output
117 for (auto& sub_graph : model->subgraphs) {
118 for (auto& op : sub_graph->operators) {
119 if (tflite::GetBuiltinCode(
120 model->operator_codes[op->opcode_index].get()) ==
121 tflite::BuiltinOperator_FULLY_CONNECTED) {
122 // inject min/max to the input and output tensors
123 auto& input_tensor = sub_graph->tensors[op->inputs[0]];
124 input_tensor->quantization->scale.clear();
125 input_tensor->quantization->zero_point.clear();
126 input_tensor->quantization->min.push_back(-1.0);
127 input_tensor->quantization->max.push_back(1.0);
128
129 auto& output_tensor = sub_graph->tensors[op->outputs[0]];
130 auto shape = output_tensor->shape;
131 output_tensor->quantization->scale.clear();
132 output_tensor->quantization->zero_point.clear();
133 for (int i = 0; i < shape.back(); ++i) {
134 output_tensor->quantization->min.push_back(-1.0 * i);
135 output_tensor->quantization->max.push_back(1.0 * i);
136 }
137 output_tensor->quantization->quantized_dimension = shape.size() - 1;
138 }
139 }
140 }
141 return model;
142 }
143
144 } // namespace
145 } // namespace mlir
146
main(int argc,char ** argv)147 int main(int argc, char** argv) {
148 llvm::InitLLVM y(argc, argv);
149 llvm::cl::ParseCommandLineOptions(argc, argv);
150 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
151 if (std::error_code error = file_or_err.getError()) {
152 llvm::errs() << argv[0] << ": could not open input file '" << inputFileName
153 << "': " << error.message() << "\n";
154 return 1;
155 }
156 auto buffer = file_or_err->get();
157 auto maybe_module =
158 mlir::InjectStatsToFullyConnected(buffer->getBuffer().str());
159 if (!maybe_module.has_value()) {
160 return 1;
161 }
162 flatbuffers::FlatBufferBuilder builder;
163 flatbuffers::Offset<tflite::Model> output_model_location =
164 tflite::Model::Pack(builder, maybe_module.getValue().get());
165 tflite::FinishModelBuffer(builder, output_model_location);
166 std::string output_model_content(
167 reinterpret_cast<const char*>(builder.GetBufferPointer()),
168 builder.GetSize());
169 std::cout << output_model_content << "\n";
170 return 0;
171 }
172