1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h"
17
18 #include <memory>
19 #include <optional>
20 #include <ostream>
21 #include <string>
22 #include <utility>
23
24 #include "absl/status/status.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/lite/interpreter.h"
27 #include "tensorflow/lite/kernels/register.h"
28 #include "tensorflow/lite/model.h"
29 #include "tensorflow/lite/string_type.h"
30
operator <<(std::ostream & os,const TfLiteTensor & tensor)31 std::ostream& operator<<(std::ostream& os, const TfLiteTensor& tensor) {
32 std::string shape;
33 std::optional<std::string> result = tflite::ShapeToString(tensor.dims);
34 if (result.has_value()) {
35 shape = std::move(result.value());
36 } else {
37 shape = "[error: unsupported number of dimensions]";
38 }
39 return os << "tensor of shape " << shape;
40 }
41
42 namespace tflite {
43
ShapeToString(TfLiteIntArray * shape)44 std::optional<std::string> ShapeToString(TfLiteIntArray* shape) {
45 std::string result;
46 int* data = shape->data;
47 switch (shape->size) {
48 case 1:
49 result = absl::Substitute("Linear=[$0]", data[0]);
50 break;
51 case 2:
52 result = absl::Substitute("HW=[$0, $1]", data[0], data[1]);
53 break;
54 case 3:
55 result = absl::Substitute("HWC=[$0, $1, $2]", data[0], data[1], data[2]);
56 break;
57 case 4:
58 result = absl::Substitute("BHWC=[$0, $1, $2, $3]", data[0], data[1],
59 data[2], data[3]);
60 break;
61 default:
62 // This printer doesn't expect shapes of more than 4 dimensions.
63 return std::nullopt;
64 }
65 return result;
66 }
67
CoordinateToString(TfLiteIntArray * shape,int linear)68 std::optional<std::string> CoordinateToString(TfLiteIntArray* shape,
69 int linear) {
70 std::string result;
71 switch (shape->size) {
72 case 1: {
73 result = absl::Substitute("[$0]", linear);
74 break;
75 } break;
76 case 2: {
77 const int tensor_width = shape->data[1];
78 const int h_coord = linear / tensor_width;
79 const int w_coord = linear % tensor_width;
80 result = absl::Substitute("[$0, $1]", h_coord, w_coord);
81 break;
82 } break;
83 case 3: {
84 const int tensor_width = shape->data[1];
85 const int tensor_channels = shape->data[2];
86 const int h_coord = linear / (tensor_width * tensor_channels);
87 const int w_coord =
88 (linear % (tensor_width * tensor_channels)) / tensor_channels;
89 const int c_coord =
90 (linear % (tensor_width * tensor_channels)) % tensor_channels;
91 result = absl::Substitute("[$0, $1, $2]", h_coord, w_coord, c_coord);
92 break;
93 } break;
94 case 4: {
95 const int tensor_height = shape->data[1];
96 const int tensor_width = shape->data[2];
97 const int tensor_channels = shape->data[3];
98 const int b_coord =
99 linear / (tensor_height * tensor_width * tensor_channels);
100 const int h_coord =
101 (linear % (tensor_height * tensor_width * tensor_channels)) /
102 (tensor_width * tensor_channels);
103 const int w_coord =
104 ((linear % (tensor_height * tensor_width * tensor_channels)) %
105 (tensor_width * tensor_channels)) /
106 tensor_channels;
107 const int c_coord =
108 ((linear % (tensor_height * tensor_width * tensor_channels)) %
109 (tensor_width * tensor_channels)) %
110 tensor_channels;
111 result = absl::Substitute("[$0, $1, $2, $3]", b_coord, h_coord, w_coord,
112 c_coord);
113 break;
114 }
115 default:
116 // This printer doesn't expect shapes of more than 4 dimensions.
117 return std::nullopt;
118 }
119 return result;
120 }
121
122 // Builds interpreter for a model, allocates tensors.
BuildInterpreter(const Model * model,std::unique_ptr<Interpreter> * interpreter)123 absl::Status BuildInterpreter(const Model* model,
124 std::unique_ptr<Interpreter>* interpreter) {
125 TfLiteStatus status =
126 InterpreterBuilder(model, ops::builtin::BuiltinOpResolver())(interpreter);
127 if (status != kTfLiteOk || !*interpreter) {
128 return absl::InternalError(
129 "Failed to initialize interpreter with model binary.");
130 }
131 return absl::OkStatus();
132 }
133
AllocateTensors(std::unique_ptr<Interpreter> * interpreter)134 absl::Status AllocateTensors(std::unique_ptr<Interpreter>* interpreter) {
135 if ((*interpreter)->AllocateTensors() != kTfLiteOk) {
136 return absl::InternalError("Failed to allocate tensors.");
137 }
138 return absl::OkStatus();
139 }
140
ModifyGraphWithDelegate(std::unique_ptr<Interpreter> * interpreter,TfLiteDelegate * delegate)141 absl::Status ModifyGraphWithDelegate(std::unique_ptr<Interpreter>* interpreter,
142 TfLiteDelegate* delegate) {
143 if ((*interpreter)->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
144 return absl::InternalError("Failed to modify graph with delegate.");
145 }
146 return absl::OkStatus();
147 }
148
InitializeInputs(int left,int right,std::unique_ptr<Interpreter> * interpreter)149 void InitializeInputs(int left, int right,
150 std::unique_ptr<Interpreter>* interpreter) {
151 for (int id : (*interpreter)->inputs()) {
152 float* input_data = (*interpreter)->typed_tensor<float>(id);
153 int input_size = (*interpreter)->input_tensor(id)->bytes;
154 for (int i = 0; i < input_size; i++) {
155 input_data[i] = left + i % right;
156 }
157 }
158 }
159
Invoke(std::unique_ptr<Interpreter> * interpreter)160 absl::Status Invoke(std::unique_ptr<Interpreter>* interpreter) {
161 if ((*interpreter)->Invoke() != kTfLiteOk) {
162 return absl::InternalError("Failed during inference.");
163 }
164 return absl::OkStatus();
165 }
166
operator <<(std::ostream & os,const TestParams & param)167 std::ostream& operator<<(std::ostream& os, const TestParams& param) {
168 return os << param.name;
169 }
170
171 } // namespace tflite
172