xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h"
17 
18 #include <memory>
19 #include <optional>
20 #include <ostream>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/status/status.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/lite/interpreter.h"
27 #include "tensorflow/lite/kernels/register.h"
28 #include "tensorflow/lite/model.h"
29 #include "tensorflow/lite/string_type.h"
30 
operator <<(std::ostream & os,const TfLiteTensor & tensor)31 std::ostream& operator<<(std::ostream& os, const TfLiteTensor& tensor) {
32   std::string shape;
33   std::optional<std::string> result = tflite::ShapeToString(tensor.dims);
34   if (result.has_value()) {
35     shape = std::move(result.value());
36   } else {
37     shape = "[error: unsupported number of dimensions]";
38   }
39   return os << "tensor of shape " << shape;
40 }
41 
42 namespace tflite {
43 
ShapeToString(TfLiteIntArray * shape)44 std::optional<std::string> ShapeToString(TfLiteIntArray* shape) {
45   std::string result;
46   int* data = shape->data;
47   switch (shape->size) {
48     case 1:
49       result = absl::Substitute("Linear=[$0]", data[0]);
50       break;
51     case 2:
52       result = absl::Substitute("HW=[$0, $1]", data[0], data[1]);
53       break;
54     case 3:
55       result = absl::Substitute("HWC=[$0, $1, $2]", data[0], data[1], data[2]);
56       break;
57     case 4:
58       result = absl::Substitute("BHWC=[$0, $1, $2, $3]", data[0], data[1],
59                                 data[2], data[3]);
60       break;
61     default:
62       // This printer doesn't expect shapes of more than 4 dimensions.
63       return std::nullopt;
64   }
65   return result;
66 }
67 
CoordinateToString(TfLiteIntArray * shape,int linear)68 std::optional<std::string> CoordinateToString(TfLiteIntArray* shape,
69                                               int linear) {
70   std::string result;
71   switch (shape->size) {
72     case 1: {
73       result = absl::Substitute("[$0]", linear);
74       break;
75     } break;
76     case 2: {
77       const int tensor_width = shape->data[1];
78       const int h_coord = linear / tensor_width;
79       const int w_coord = linear % tensor_width;
80       result = absl::Substitute("[$0, $1]", h_coord, w_coord);
81       break;
82     } break;
83     case 3: {
84       const int tensor_width = shape->data[1];
85       const int tensor_channels = shape->data[2];
86       const int h_coord = linear / (tensor_width * tensor_channels);
87       const int w_coord =
88           (linear % (tensor_width * tensor_channels)) / tensor_channels;
89       const int c_coord =
90           (linear % (tensor_width * tensor_channels)) % tensor_channels;
91       result = absl::Substitute("[$0, $1, $2]", h_coord, w_coord, c_coord);
92       break;
93     } break;
94     case 4: {
95       const int tensor_height = shape->data[1];
96       const int tensor_width = shape->data[2];
97       const int tensor_channels = shape->data[3];
98       const int b_coord =
99           linear / (tensor_height * tensor_width * tensor_channels);
100       const int h_coord =
101           (linear % (tensor_height * tensor_width * tensor_channels)) /
102           (tensor_width * tensor_channels);
103       const int w_coord =
104           ((linear % (tensor_height * tensor_width * tensor_channels)) %
105            (tensor_width * tensor_channels)) /
106           tensor_channels;
107       const int c_coord =
108           ((linear % (tensor_height * tensor_width * tensor_channels)) %
109            (tensor_width * tensor_channels)) %
110           tensor_channels;
111       result = absl::Substitute("[$0, $1, $2, $3]", b_coord, h_coord, w_coord,
112                                 c_coord);
113       break;
114     }
115     default:
116       // This printer doesn't expect shapes of more than 4 dimensions.
117       return std::nullopt;
118   }
119   return result;
120 }
121 
122 // Builds interpreter for a model, allocates tensors.
BuildInterpreter(const Model * model,std::unique_ptr<Interpreter> * interpreter)123 absl::Status BuildInterpreter(const Model* model,
124                               std::unique_ptr<Interpreter>* interpreter) {
125   TfLiteStatus status =
126       InterpreterBuilder(model, ops::builtin::BuiltinOpResolver())(interpreter);
127   if (status != kTfLiteOk || !*interpreter) {
128     return absl::InternalError(
129         "Failed to initialize interpreter with model binary.");
130   }
131   return absl::OkStatus();
132 }
133 
AllocateTensors(std::unique_ptr<Interpreter> * interpreter)134 absl::Status AllocateTensors(std::unique_ptr<Interpreter>* interpreter) {
135   if ((*interpreter)->AllocateTensors() != kTfLiteOk) {
136     return absl::InternalError("Failed to allocate tensors.");
137   }
138   return absl::OkStatus();
139 }
140 
ModifyGraphWithDelegate(std::unique_ptr<Interpreter> * interpreter,TfLiteDelegate * delegate)141 absl::Status ModifyGraphWithDelegate(std::unique_ptr<Interpreter>* interpreter,
142                                      TfLiteDelegate* delegate) {
143   if ((*interpreter)->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
144     return absl::InternalError("Failed to modify graph with delegate.");
145   }
146   return absl::OkStatus();
147 }
148 
InitializeInputs(int left,int right,std::unique_ptr<Interpreter> * interpreter)149 void InitializeInputs(int left, int right,
150                       std::unique_ptr<Interpreter>* interpreter) {
151   for (int id : (*interpreter)->inputs()) {
152     float* input_data = (*interpreter)->typed_tensor<float>(id);
153     int input_size = (*interpreter)->input_tensor(id)->bytes;
154     for (int i = 0; i < input_size; i++) {
155       input_data[i] = left + i % right;
156     }
157   }
158 }
159 
Invoke(std::unique_ptr<Interpreter> * interpreter)160 absl::Status Invoke(std::unique_ptr<Interpreter>* interpreter) {
161   if ((*interpreter)->Invoke() != kTfLiteOk) {
162     return absl::InternalError("Failed during inference.");
163   }
164   return absl::OkStatus();
165 }
166 
operator <<(std::ostream & os,const TestParams & param)167 std::ostream& operator<<(std::ostream& os, const TestParams& param) {
168   return os << param.name;
169 }
170 
171 }  // namespace tflite
172