xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/shim/test_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/kernels/shim/test_util.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <sstream>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/string_util.h"
23 
24 namespace tflite {
25 
26 using std::size_t;
27 
get()28 TfLiteTensor* UniqueTfLiteTensor::get() { return tensor_; }
29 
operator *()30 TfLiteTensor& UniqueTfLiteTensor::operator*() { return *tensor_; }
31 
operator ->()32 TfLiteTensor* UniqueTfLiteTensor::operator->() { return tensor_; }
33 
operator ->() const34 const TfLiteTensor* UniqueTfLiteTensor::operator->() const { return tensor_; }
35 
reset(TfLiteTensor * tensor)36 void UniqueTfLiteTensor::reset(TfLiteTensor* tensor) { tensor_ = tensor; }
37 
~UniqueTfLiteTensor()38 UniqueTfLiteTensor::~UniqueTfLiteTensor() { TfLiteTensorFree(tensor_); }
39 
40 namespace {
41 
42 template <typename T>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)43 std::string TensorValueToString(const ::TfLiteTensor* tensor,
44                                 const size_t idx) {
45   TFLITE_DCHECK_EQ(tensor->type, ::tflite::typeToTfLiteType<T>());
46   const T* val_array = reinterpret_cast<const T*>(tensor->data.raw);
47   return std::to_string(val_array[idx]);
48 }
49 
50 template <>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)51 std::string TensorValueToString<bool>(const ::TfLiteTensor* tensor,
52                                       const size_t idx) {
53   TFLITE_DCHECK_EQ(tensor->type, ::tflite::typeToTfLiteType<bool>());
54   const bool* val_array = reinterpret_cast<const bool*>(tensor->data.raw);
55   return val_array[idx] ? "1" : "0";
56 }
57 
58 template <typename FloatType>
TensorValueToStringFloat(const::TfLiteTensor * tensor,const size_t idx)59 std::string TensorValueToStringFloat(const ::TfLiteTensor* tensor,
60                                      const size_t idx) {
61   TFLITE_DCHECK_EQ(tensor->type, ::tflite::typeToTfLiteType<FloatType>());
62   const FloatType* val_array =
63       reinterpret_cast<const FloatType*>(tensor->data.raw);
64   std::stringstream ss;
65   ss << val_array[idx];
66   return std::string(ss.str().data(), ss.str().length());
67 }
68 
69 template <>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)70 std::string TensorValueToString<float>(const ::TfLiteTensor* tensor,
71                                        const size_t idx) {
72   return TensorValueToStringFloat<float>(tensor, idx);
73 }
74 
75 template <>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)76 std::string TensorValueToString<double>(const ::TfLiteTensor* tensor,
77                                         const size_t idx) {
78   return TensorValueToStringFloat<double>(tensor, idx);
79 }
80 
81 template <>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)82 std::string TensorValueToString<StringRef>(const ::TfLiteTensor* tensor,
83                                            const size_t idx) {
84   TFLITE_DCHECK_EQ(tensor->type, kTfLiteString);
85   const auto ref = ::tflite::GetString(tensor, idx);
86   return std::string(ref.str, ref.len);
87 }
88 
TfliteTensorDebugStringImpl(const::TfLiteTensor * tensor,const size_t axis,const size_t max_values,size_t * start_idx)89 std::string TfliteTensorDebugStringImpl(const ::TfLiteTensor* tensor,
90                                         const size_t axis,
91                                         const size_t max_values,
92                                         size_t* start_idx) {
93   const size_t dim_size = tensor->dims->data[axis];
94   if (axis == tensor->dims->size - 1) {
95     std::vector<std::string> ret_list;
96     ret_list.reserve(dim_size);
97     int idx = *start_idx;
98     for (int i = 0; i < dim_size && idx < max_values; ++i, ++idx) {
99       std::string val_str;
100       switch (tensor->type) {
101         case kTfLiteBool: {
102           val_str = TensorValueToString<bool>(tensor, idx);
103           break;
104         }
105         case kTfLiteUInt8: {
106           val_str = TensorValueToString<uint8_t>(tensor, idx);
107           break;
108         }
109         case kTfLiteInt8: {
110           val_str = TensorValueToString<int8_t>(tensor, idx);
111           break;
112         }
113         case kTfLiteInt16: {
114           val_str = TensorValueToString<int16_t>(tensor, idx);
115           break;
116         }
117         case kTfLiteInt32: {
118           val_str = TensorValueToString<int32_t>(tensor, idx);
119           break;
120         }
121         case kTfLiteInt64: {
122           val_str = TensorValueToString<int64_t>(tensor, idx);
123           break;
124         }
125         case kTfLiteString: {
126           val_str = TensorValueToString<StringRef>(tensor, idx);
127           break;
128         }
129         case kTfLiteFloat32: {
130           val_str = TensorValueToString<float>(tensor, idx);
131           break;
132         }
133         case kTfLiteFloat64: {
134           val_str = TensorValueToString<double>(tensor, idx);
135           break;
136         }
137         default: {
138           val_str = "unsupported_type";
139         }
140       }
141       ret_list.push_back(val_str);
142     }
143     *start_idx = idx;
144     if (idx == max_values && ret_list.size() < dim_size) {
145       ret_list.push_back("...");
146     }
147     return absl::StrCat("[", absl::StrJoin(ret_list, ", "), "]");
148   } else {
149     std::vector<std::string> ret_list;
150     ret_list.reserve(dim_size);
151     for (int i = 0; i < dim_size && *start_idx < max_values; ++i) {
152       ret_list.push_back(
153           TfliteTensorDebugStringImpl(tensor, axis + 1, max_values, start_idx));
154     }
155     return absl::StrCat("[", absl::StrJoin(ret_list, ", "), "]");
156   }
157 }
158 
159 }  // namespace
160 
TfliteTensorDebugString(const::TfLiteTensor * tensor,const size_t max_values)161 std::string TfliteTensorDebugString(const ::TfLiteTensor* tensor,
162                                     const size_t max_values) {
163   if (tensor->dims->size == 0) return "";
164   size_t start_idx = 0;
165   return TfliteTensorDebugStringImpl(tensor, 0, max_values, &start_idx);
166 }
167 
// Returns the element count implied by `shape`, i.e. the product of its
// extents. An empty shape yields 0, not 1. Extents are multiplied as size_t,
// so a negative extent wraps instead of being rejected.
size_t NumTotalFromShape(const std::initializer_list<int>& shape) {
  if (shape.size() == 0) {
    return 0;
  }
  size_t product = 1;
  for (auto it = shape.begin(); it != shape.end(); ++it) {
    product *= static_cast<size_t>(*it);
  }
  return product;
}
177 
178 template <>
PopulateTfLiteTensorValue(const std::initializer_list<std::string> values,TfLiteTensor * tensor)179 void PopulateTfLiteTensorValue<std::string>(
180     const std::initializer_list<std::string> values, TfLiteTensor* tensor) {
181   tflite::DynamicBuffer buf;
182   for (const std::string& s : values) {
183     buf.AddString(s.data(), s.length());
184   }
185   buf.WriteToTensor(tensor, /*new_shape=*/nullptr);
186 }
187 
188 }  // namespace tflite
189