1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/shim/tf_tensor_view.h"
16
17 #include <utility>
18
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/core/framework/tensor.pb.h"
23 #include "tensorflow/core/platform/protobuf.h"
24 #include "tensorflow/core/platform/tstring.h"
25
26 namespace tflite {
27 namespace shim {
28 namespace {
29
30 using ::tensorflow::protobuf::TextFormat;
31
// Checks that a TensorView over a DT_BOOL tensorflow::Tensor writes through to
// the wrapped tensor: values stored via the view must show up in
// tf_tensor.SummarizeValue().
TEST(TfTensorView, Bool) {
  ::tensorflow::TensorProto tf_tensor_pb;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_BOOL
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        bool_val: [ false, false, false, false, false, false ]
      )pb",
      &tf_tensor_pb));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(tf_tensor_pb));

  // Test move construction: `t` is initialized from an rvalue reference to the
  // view held in the result of TensorView::New, which exercises the move
  // constructor (this is initialization, not move assignment).
  auto t_premove_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(t_premove_or.ok()) << t_premove_or.status();
  auto t = std::move(t_premove_or.value());

  // Only flat indices 0 and 5 satisfy i % 5 == 0, i.e. the first and last
  // elements of the 3x2 tensor become true.
  auto tensor_data_as_vector = t.Data<bool>();
  for (int i = 0; i < 3 * 2; ++i) tensor_data_as_vector[i] = i % 5 == 0;

  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[1 0]\n [0 0]\n [0 1]]"));
}
59
// Checks that a TensorView over a DT_INT32 tensorflow::Tensor writes through
// to the wrapped tensor: values stored via the view must show up in
// tf_tensor.SummarizeValue().
TEST(TfTensorView, Int32) {
  ::tensorflow::TensorProto tf_tensor_pb;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_INT32
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        int_val: [ 0, 0, 0, 0, 0, 0 ]
      )pb",
      &tf_tensor_pb));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(tf_tensor_pb));

  // Test move construction: `t` is initialized from an rvalue reference to the
  // view held in the result of TensorView::New, which exercises the move
  // constructor (this is initialization, not move assignment).
  auto t_premove_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(t_premove_or.ok()) << t_premove_or.status();
  auto t = std::move(t_premove_or.value());

  // Fill the 3x2 tensor with its flat index (0..5) in row-major order.
  auto tensor_data_as_vector = t.Data<int32_t>();
  for (int i = 0; i < 3 * 2; ++i) tensor_data_as_vector[i] = i;

  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[0 1]\n [2 3]\n [4 5]]"));
}
87
// Checks that a TensorView over a DT_INT64 tensorflow::Tensor writes through
// to the wrapped tensor: values stored via the view must show up in
// tf_tensor.SummarizeValue().
TEST(TfTensorView, Int64) {
  ::tensorflow::TensorProto tf_tensor_pb;
  // DT_INT64 payloads are carried in TensorProto's `int64_val` field; `int_val`
  // is the field for DT_INT32 and narrower integer types, so it must not be
  // used to initialize a DT_INT64 tensor.
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_INT64
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        int64_val: [ 0, 0, 0, 0, 0, 0 ]
      )pb",
      &tf_tensor_pb));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(tf_tensor_pb));
  auto t_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(t_or.ok()) << t_or.status();
  auto& t = t_or.value();

  // Fill the 3x2 tensor with its flat index (0..5) in row-major order.
  auto tensor_data_as_vector = t.Data<int64_t>();
  for (int i = 0; i < 3 * 2; ++i) tensor_data_as_vector[i] = i;

  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[0 1]\n [2 3]\n [4 5]]"));
}
113
// Checks that a TensorView over a DT_FLOAT tensorflow::Tensor writes through
// to the wrapped tensor: values stored via the view must show up in
// tf_tensor.SummarizeValue().
TEST(TfTensorView, Float) {
  ::tensorflow::TensorProto tf_tensor_pb;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_FLOAT
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        float_val: [ 0, 0, 0, 0, 0, 0 ]
      )pb",
      &tf_tensor_pb));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(tf_tensor_pb));
  auto t_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(t_or.ok()) << t_or.status();
  auto& t = t_or.value();

  // Store i / 2 for each flat index. The divisor is a float literal so the
  // arithmetic stays in float; a double divisor would promote to double and
  // then implicitly narrow back to float on assignment.
  auto tensor_data_as_vector = t.Data<float>();
  for (int i = 0; i < 3 * 2; ++i)
    tensor_data_as_vector[i] = static_cast<float>(i) / 2.0f;

  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[0 0.5]\n [1 1.5]\n [2 2.5]]"));
}
140
// A TensorView over a DT_DOUBLE tensorflow::Tensor must alias the tensor's
// storage: halves of the flat index written through the view are expected to
// appear in tf_tensor.SummarizeValue().
TEST(TfTensorView, Double) {
  ::tensorflow::TensorProto tf_tensor_pb;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_DOUBLE
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        double_val: [ 0, 0, 0, 0, 0, 0 ]
      )pb",
      &tf_tensor_pb));

  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(tf_tensor_pb));

  auto view_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(view_or.ok()) << view_or.status();
  auto& view = view_or.value();

  // 3 rows x 2 columns, written in flat (row-major) order.
  const int num_elements = 3 * 2;
  auto values = view.Data<double>();
  for (int k = 0; k < num_elements; ++k) {
    values[k] = 0.5 * k;  // Exactly equal to k / 2.0 for these indices.
  }

  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[0 0.5]\n [1 1.5]\n [2 2.5]]"));
}
167
// A TensorView over a DT_STRING tensorflow::Tensor must alias the tensor's
// storage. Strings written through a mutable view are read back through both
// the same view and a view built from a const reference, and must also appear
// in tf_tensor.SummarizeValue().
TEST(TfTensorView, Str) {
  ::tensorflow::TensorProto tf_tensor_pb;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_STRING
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        string_val: [ "", "", "", "", "", "" ]
      )pb",
      &tf_tensor_pb));

  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(tf_tensor_pb));

  auto view_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(view_or.ok()) << view_or.status();
  auto& view = view_or.value();

  // Write the six strings in flat (row-major) order.
  const char* const contents[] = {"a", "bc", "def", "g", "hi", ""};
  auto cells = view.Data<::tensorflow::tstring>();
  for (int k = 0; k < 6; ++k) {
    cells[k] = contents[k];
  }

  EXPECT_THAT(view.Data<::tensorflow::tstring>(),
              ::testing::ElementsAre("a", "bc", "def", "g", "hi", ""));

  // The same data must be visible through a view created from a const tensor.
  const ::tensorflow::Tensor& readonly_tensor = tf_tensor;
  const auto readonly_view_or = TensorView::New(&readonly_tensor);
  ASSERT_TRUE(readonly_view_or.ok()) << readonly_view_or.status();
  const auto& readonly_view = readonly_view_or.value();

  EXPECT_THAT(readonly_view.Data<::tensorflow::tstring>(),
              ::testing::ElementsAre("a", "bc", "def", "g", "hi", ""));

  // The leading newline only keeps the literal readable; substr(1) drops it.
  const char expected_summary[] = R"(
[["a" "bc"]
 ["def" "g"]
 ["hi" ""]])";

  EXPECT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq(absl::string_view(expected_summary).substr(1)));
}
214
215 } // namespace
216 } // namespace shim
217 } // namespace tflite
218