/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iostream>
#include <memory>

#include "absl/strings/string_view.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"

using llvm::Optional;
using llvm::cl::opt;

// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s.mlir -o - \
// RUN:   | %p/importer_test_min_max - \
// RUN:   | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - \
// RUN:   | FileCheck %s

// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s.mlir -o - \
// RUN:   | %p/importer_test_min_max - \
// RUN:   | flatbuffer_to_string - \
// RUN:   | FileCheck --check-prefix=FB %s

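// The RUN lines above are LLVM lit directives. The first pipeline converts
// the companion .mlir file to a TFLite flatbuffer, pipes it through this
// tool, imports the result back into MLIR, and verifies it with FileCheck;
// the second dumps the modified flatbuffer as text and checks it against
// the FB-prefixed expectations below.
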
// Tests that a TFLite model carrying min/max quantization stats can be
// imported correctly.
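//
// This binary reads a serialized TFLite model from the given file (or from
// stdin), injects min/max stats into the input and output tensors of every
// fully-connected op, and writes the repacked flatbuffer to stdout.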

// Positional input file; the default "-" reads the model from stdin.
// NOLINTNEXTLINE
static opt<std::string> inputFileName(llvm::cl::Positional,
                                      llvm::cl::desc("<input file>"),
                                      llvm::cl::init("-"));

namespace mlir {
namespace {
Optional<std::unique_ptr<tflite::ModelT>> InjectStatsToFullyConnected(
    llvm::StringRef buffer) {
  auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
      buffer.data(), buffer.size());
  if (nullptr == model_ptr) {
    return llvm::None;
  }
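  // Unpack the flatbuffer into the mutable object API (ModelT) so the
  // quantization fields can be edited in place below.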
  std::unique_ptr<tflite::ModelT> model(model_ptr->GetModel()->UnPack());

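  // FileCheck expectations: the FB-prefixed lines describe the textual dump
  // of the flatbuffer this tool emits; the CHECK-prefixed lines further down
  // describe the MLIR produced when that flatbuffer is imported back.
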
  // FB-LABEL:     name: "arg0",
  // FB-NEXT:      quantization: {
  // FB-NEXT:              min: [ -1.0 ],
  // FB-NEXT:              max: [ 1.0 ]
  // FB-NEXT:      }

  // FB-LABEL:     name: "arg1",
  // FB-NEXT:            quantization: {
  // FB-EMPTY:
  // FB-NEXT:            }

  // FB-LABEL:     name: "tfl.fully_connected",
  // FB-NEXT:      quantization: {
  // FB-NEXT:        min: [ -0.0, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0,
  // FB-SAME:  -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0, -16.0,
  // FB-SAME:  -17.0, -18.0, -19.0, -20.0, -21.0, -22.0, -23.0, -24.0, -25.0,
  // FB-SAME:  -26.0, -27.0, -28.0, -29.0, -30.0, -31.0, -32.0, -33.0, -34.0,
  // FB-SAME:  -35.0, -36.0, -37.0, -38.0, -39.0 ],
  // FB-NEXT:        max: [ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
  // FB-SAME:  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0,
  // FB-SAME:  21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
  // FB-SAME:  32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0 ],
  // FB-NEXT:        quantized_dimension: 1
  // FB-NEXT:      }

  // FB-LABEL:     name: "tfl.fully_connected:1",
  // FB-NEXT:      quantization: {
  // FB-EMPTY:
  // FB-NEXT:      }

  // FB-LABEL:      operators: [ {
  // FB-NEXT:             inputs: [ 0, 1, 2 ],
  // FB-NEXT:             outputs: [ 3, 4 ],
  // FB-NEXT:             builtin_options_type: FullyConnectedOptions,
  // FB-NEXT:             builtin_options: {
  // FB-EMPTY:
  // FB-NEXT:             }
  // FB-NEXT:       } ],

  // CHECK-LABEL: func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>)
  // CHECK-SAME:      -> tensor<40x40xf32>
  // CHECK:         %[[stat:.*]] = "quantfork.stats"(%arg0) {layerStats = dense<
  // CHECK-SAME:      [-1.000000e+00, 1.000000e+00]> : tensor<2xf32>}
  // CHECK-SAME:      : (tensor<40x37xf32>) -> tensor<40x37xf32>
  // CHECK-NEXT:    %[[cst:.*]] = "tfl.pseudo_const"() {value = dense<
  // CHECK-SAME:      1.000000e+00> : tensor<40xf32>} : () -> tensor<40xf32>
  // CHECK-NEXT:    %[[fc:.*]]:2 = "tfl.fully_connected"(%[[stat]], %arg1,
  // CHECK-NEXT:    %[[stat1:.*]] = "quantfork.stats"(%[[fc]]#0)
  // CHECK-SAME:    {axis = 1 : i64,
  // CHECK-SAME:      axisStats = dense<{{\[}}[-0.000000e+00, 0.000000e+00],
  // CHECK-SAME:      [-1.000000e+00, 1.000000e+00],
  // CHECK-SAME:      [-2.000000e+00, 2.000000e+00]
  // CHECK-NEXT:    return %[[stat1]] : tensor<40x40xf32>
  // CHECK-NEXT:  }

  // Find the fully-connected ops and inject min/max stats into their input
  // and output tensors.
  for (auto& sub_graph : model->subgraphs) {
    for (auto& op : sub_graph->operators) {
      if (tflite::GetBuiltinCode(
              model->operator_codes[op->opcode_index].get()) ==
          tflite::BuiltinOperator_FULLY_CONNECTED) {
        // Give the input tensor a single [-1.0, 1.0] min/max range, clearing
        // any scale/zero-point so only the stats remain.
        auto& input_tensor = sub_graph->tensors[op->inputs[0]];
        input_tensor->quantization->scale.clear();
        input_tensor->quantization->zero_point.clear();
        input_tensor->quantization->min.push_back(-1.0);
        input_tensor->quantization->max.push_back(1.0);

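        // Give the output tensor per-channel stats: channel i gets the range
        // [-i, i], with the quantized dimension set to the last axis.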
        auto& output_tensor = sub_graph->tensors[op->outputs[0]];
        auto shape = output_tensor->shape;
        output_tensor->quantization->scale.clear();
        output_tensor->quantization->zero_point.clear();
        for (int i = 0; i < shape.back(); ++i) {
          output_tensor->quantization->min.push_back(-1.0 * i);
          output_tensor->quantization->max.push_back(1.0 * i);
        }
        output_tensor->quantization->quantized_dimension = shape.size() - 1;
      }
    }
  }
  return model;
}

}  // namespace
}  // namespace mlir

int main(int argc, char** argv) {
  llvm::InitLLVM y(argc, argv);
  llvm::cl::ParseCommandLineOptions(argc, argv);
  auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
  if (std::error_code error = file_or_err.getError()) {
    llvm::errs() << argv[0] << ": could not open input file '" << inputFileName
                 << "': " << error.message() << "\n";
    return 1;
  }
  auto buffer = file_or_err->get();
  auto maybe_module =
      mlir::InjectStatsToFullyConnected(buffer->getBuffer().str());
  if (!maybe_module.has_value()) {
    return 1;
  }
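  // Pack the mutated model back into a flatbuffer and write the bytes to
  // stdout so the result can be piped into flatbuffer_translate or
  // flatbuffer_to_string.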
  flatbuffers::FlatBufferBuilder builder;
  flatbuffers::Offset<tflite::Model> output_model_location =
      tflite::Model::Pack(builder, maybe_module.getValue().get());
  tflite::FinishModelBuffer(builder, output_model_location);
  std::string output_model_content(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize());
  std::cout << output_model_content << "\n";
  return 0;
}