1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/shim/test_util.h"
16
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/string_util.h"
23
24 namespace tflite {
25
26 using std::size_t;
27
get()28 TfLiteTensor* UniqueTfLiteTensor::get() { return tensor_; }
29
operator *()30 TfLiteTensor& UniqueTfLiteTensor::operator*() { return *tensor_; }
31
operator ->()32 TfLiteTensor* UniqueTfLiteTensor::operator->() { return tensor_; }
33
operator ->() const34 const TfLiteTensor* UniqueTfLiteTensor::operator->() const { return tensor_; }
35
reset(TfLiteTensor * tensor)36 void UniqueTfLiteTensor::reset(TfLiteTensor* tensor) { tensor_ = tensor; }
37
~UniqueTfLiteTensor()38 UniqueTfLiteTensor::~UniqueTfLiteTensor() { TfLiteTensorFree(tensor_); }
39
40 namespace {
41
42 template <typename T>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)43 std::string TensorValueToString(const ::TfLiteTensor* tensor,
44 const size_t idx) {
45 TFLITE_DCHECK_EQ(tensor->type, ::tflite::typeToTfLiteType<T>());
46 const T* val_array = reinterpret_cast<const T*>(tensor->data.raw);
47 return std::to_string(val_array[idx]);
48 }
49
50 template <>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)51 std::string TensorValueToString<bool>(const ::TfLiteTensor* tensor,
52 const size_t idx) {
53 TFLITE_DCHECK_EQ(tensor->type, ::tflite::typeToTfLiteType<bool>());
54 const bool* val_array = reinterpret_cast<const bool*>(tensor->data.raw);
55 return val_array[idx] ? "1" : "0";
56 }
57
58 template <typename FloatType>
TensorValueToStringFloat(const::TfLiteTensor * tensor,const size_t idx)59 std::string TensorValueToStringFloat(const ::TfLiteTensor* tensor,
60 const size_t idx) {
61 TFLITE_DCHECK_EQ(tensor->type, ::tflite::typeToTfLiteType<FloatType>());
62 const FloatType* val_array =
63 reinterpret_cast<const FloatType*>(tensor->data.raw);
64 std::stringstream ss;
65 ss << val_array[idx];
66 return std::string(ss.str().data(), ss.str().length());
67 }
68
69 template <>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)70 std::string TensorValueToString<float>(const ::TfLiteTensor* tensor,
71 const size_t idx) {
72 return TensorValueToStringFloat<float>(tensor, idx);
73 }
74
75 template <>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)76 std::string TensorValueToString<double>(const ::TfLiteTensor* tensor,
77 const size_t idx) {
78 return TensorValueToStringFloat<double>(tensor, idx);
79 }
80
81 template <>
TensorValueToString(const::TfLiteTensor * tensor,const size_t idx)82 std::string TensorValueToString<StringRef>(const ::TfLiteTensor* tensor,
83 const size_t idx) {
84 TFLITE_DCHECK_EQ(tensor->type, kTfLiteString);
85 const auto ref = ::tflite::GetString(tensor, idx);
86 return std::string(ref.str, ref.len);
87 }
88
TfliteTensorDebugStringImpl(const::TfLiteTensor * tensor,const size_t axis,const size_t max_values,size_t * start_idx)89 std::string TfliteTensorDebugStringImpl(const ::TfLiteTensor* tensor,
90 const size_t axis,
91 const size_t max_values,
92 size_t* start_idx) {
93 const size_t dim_size = tensor->dims->data[axis];
94 if (axis == tensor->dims->size - 1) {
95 std::vector<std::string> ret_list;
96 ret_list.reserve(dim_size);
97 int idx = *start_idx;
98 for (int i = 0; i < dim_size && idx < max_values; ++i, ++idx) {
99 std::string val_str;
100 switch (tensor->type) {
101 case kTfLiteBool: {
102 val_str = TensorValueToString<bool>(tensor, idx);
103 break;
104 }
105 case kTfLiteUInt8: {
106 val_str = TensorValueToString<uint8_t>(tensor, idx);
107 break;
108 }
109 case kTfLiteInt8: {
110 val_str = TensorValueToString<int8_t>(tensor, idx);
111 break;
112 }
113 case kTfLiteInt16: {
114 val_str = TensorValueToString<int16_t>(tensor, idx);
115 break;
116 }
117 case kTfLiteInt32: {
118 val_str = TensorValueToString<int32_t>(tensor, idx);
119 break;
120 }
121 case kTfLiteInt64: {
122 val_str = TensorValueToString<int64_t>(tensor, idx);
123 break;
124 }
125 case kTfLiteString: {
126 val_str = TensorValueToString<StringRef>(tensor, idx);
127 break;
128 }
129 case kTfLiteFloat32: {
130 val_str = TensorValueToString<float>(tensor, idx);
131 break;
132 }
133 case kTfLiteFloat64: {
134 val_str = TensorValueToString<double>(tensor, idx);
135 break;
136 }
137 default: {
138 val_str = "unsupported_type";
139 }
140 }
141 ret_list.push_back(val_str);
142 }
143 *start_idx = idx;
144 if (idx == max_values && ret_list.size() < dim_size) {
145 ret_list.push_back("...");
146 }
147 return absl::StrCat("[", absl::StrJoin(ret_list, ", "), "]");
148 } else {
149 std::vector<std::string> ret_list;
150 ret_list.reserve(dim_size);
151 for (int i = 0; i < dim_size && *start_idx < max_values; ++i) {
152 ret_list.push_back(
153 TfliteTensorDebugStringImpl(tensor, axis + 1, max_values, start_idx));
154 }
155 return absl::StrCat("[", absl::StrJoin(ret_list, ", "), "]");
156 }
157 }
158
159 } // namespace
160
TfliteTensorDebugString(const::TfLiteTensor * tensor,const size_t max_values)161 std::string TfliteTensorDebugString(const ::TfLiteTensor* tensor,
162 const size_t max_values) {
163 if (tensor->dims->size == 0) return "";
164 size_t start_idx = 0;
165 return TfliteTensorDebugStringImpl(tensor, 0, max_values, &start_idx);
166 }
167
// Returns the number of elements described by `shape`: the product of its
// dimensions, with each int converted to size_t.
// NOTE(review): an empty shape yields 0, not the 1 element a rank-0 scalar
// would normally have; callers may rely on this.
size_t NumTotalFromShape(const std::initializer_list<int>& shape) {
  if (shape.size() == 0) {
    return 0;
  }
  size_t product = 1;
  for (auto it = shape.begin(); it != shape.end(); ++it) {
    product *= *it;
  }
  return product;
}
177
178 template <>
PopulateTfLiteTensorValue(const std::initializer_list<std::string> values,TfLiteTensor * tensor)179 void PopulateTfLiteTensorValue<std::string>(
180 const std::initializer_list<std::string> values, TfLiteTensor* tensor) {
181 tflite::DynamicBuffer buf;
182 for (const std::string& s : values) {
183 buf.AddString(s.data(), s.length());
184 }
185 buf.WriteToTensor(tensor, /*new_shape=*/nullptr);
186 }
187
188 } // namespace tflite
189