xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/shim/tf_tensor_view_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/shim/tf_tensor_view.h"
16 
17 #include <utility>
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/core/framework/tensor.pb.h"
23 #include "tensorflow/core/platform/protobuf.h"
24 #include "tensorflow/core/platform/tstring.h"
25 
26 namespace tflite {
27 namespace shim {
28 namespace {
29 
30 using ::tensorflow::protobuf::TextFormat;
31 
// Checks that values written through a TensorView's Data<bool>() accessor
// show up in the wrapped 3x2 DT_BOOL ::tensorflow::Tensor.
TEST(TfTensorView, Bool) {
  constexpr int kNumElements = 3 * 2;

  // Build a 3x2 tensor whose six entries all start out false.
  ::tensorflow::TensorProto proto;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_BOOL
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        bool_val: [ false, false, false, false, false, false ]
      )pb",
      &proto));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(proto));

  // Move the view out of the result returned by New() (a move construction
  // of `view`) so that the moved-to object is the one exercised below.
  auto view_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(view_or.ok()) << view_or.status();
  auto view = std::move(view_or.value());

  // Set every element whose flat index is a multiple of 5 (indices 0 and 5).
  auto values = view.Data<bool>();
  for (int idx = 0; idx < kNumElements; ++idx) {
    values[idx] = (idx % 5 == 0);
  }

  // The writes must be observable through the original tensor.
  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[1 0]\n [0 0]\n [0 1]]"));
}
59 
// Checks that values written through a TensorView's Data<int32_t>() accessor
// show up in the wrapped 3x2 DT_INT32 ::tensorflow::Tensor.
TEST(TfTensorView, Int32) {
  constexpr int kNumElements = 3 * 2;

  // Build a 3x2 tensor whose six entries all start out as zero.
  ::tensorflow::TensorProto proto;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_INT32
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        int_val: [ 0, 0, 0, 0, 0, 0 ]
      )pb",
      &proto));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(proto));

  // Move the view out of the result returned by New() (a move construction
  // of `view`) so that the moved-to object is the one exercised below.
  auto view_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(view_or.ok()) << view_or.status();
  auto view = std::move(view_or.value());

  // Store each element's flat index as its value.
  auto values = view.Data<int32_t>();
  for (int idx = 0; idx < kNumElements; ++idx) {
    values[idx] = idx;
  }

  // The writes must be observable through the original tensor.
  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[0 1]\n [2 3]\n [4 5]]"));
}
87 
// Checks that values written through a TensorView's Data<int64_t>() accessor
// show up in the wrapped 3x2 DT_INT64 ::tensorflow::Tensor.
TEST(TfTensorView, Int64) {
  constexpr int kNumElements = 3 * 2;

  // Build a 3x2 tensor whose six entries all start out as zero.
  // The values are supplied through `int64_val`, the TensorProto field that
  // carries DT_INT64 data; `int_val` is the field for the 32-bit (and
  // narrower) integer dtypes and does not match this dtype.
  ::tensorflow::TensorProto tf_tensor_pb;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_INT64
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        int64_val: [ 0, 0, 0, 0, 0, 0 ]
      )pb",
      &tf_tensor_pb));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(tf_tensor_pb));

  // Use the view in place, through a reference into the result of New().
  auto t_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(t_or.ok()) << t_or.status();
  auto& t = t_or.value();

  // Store each element's flat index as its value.
  auto tensor_data_as_vector = t.Data<int64_t>();
  for (int i = 0; i < kNumElements; ++i) {
    tensor_data_as_vector[i] = i;
  }

  // The writes must be observable through the original tensor.
  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[0 1]\n [2 3]\n [4 5]]"));
}
113 
// Checks that values written through a TensorView's Data<float>() accessor
// show up in the wrapped 3x2 DT_FLOAT ::tensorflow::Tensor.
TEST(TfTensorView, Float) {
  constexpr int kNumElements = 3 * 2;

  // Build a 3x2 tensor whose six entries all start out as zero.
  ::tensorflow::TensorProto proto;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_FLOAT
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        float_val: [ 0, 0, 0, 0, 0, 0 ]
      )pb",
      &proto));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(proto));

  // Use the view in place, through a reference into the result of New().
  auto view_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(view_or.ok()) << view_or.status();
  auto& view = view_or.value();

  // Element idx receives idx / 2, i.e. 0, 0.5, 1, 1.5, 2, 2.5. The division
  // is carried out in double and then narrowed to float.
  auto values = view.Data<float>();
  for (int idx = 0; idx < kNumElements; ++idx) {
    values[idx] = static_cast<float>(idx) / 2.0;
  }

  // The writes must be observable through the original tensor.
  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[0 0.5]\n [1 1.5]\n [2 2.5]]"));
}
140 
// Checks that values written through a TensorView's Data<double>() accessor
// show up in the wrapped 3x2 DT_DOUBLE ::tensorflow::Tensor.
TEST(TfTensorView, Double) {
  constexpr int kNumElements = 3 * 2;

  // Build a 3x2 tensor whose six entries all start out as zero.
  ::tensorflow::TensorProto proto;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_DOUBLE
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        double_val: [ 0, 0, 0, 0, 0, 0 ]
      )pb",
      &proto));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(proto));

  // Use the view in place, through a reference into the result of New().
  auto view_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(view_or.ok()) << view_or.status();
  auto& view = view_or.value();

  // Element idx receives idx / 2, i.e. 0, 0.5, 1, 1.5, 2, 2.5.
  auto values = view.Data<double>();
  for (int idx = 0; idx < kNumElements; ++idx) {
    values[idx] = static_cast<double>(idx) / 2.0;
  }

  // The writes must be observable through the original tensor.
  ASSERT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq("[[0 0.5]\n [1 1.5]\n [2 2.5]]"));
}
167 
// Checks string tensors: values written through Data<::tensorflow::tstring>()
// can be read back through both a mutable and a const TensorView, and show up
// in the wrapped 3x2 DT_STRING ::tensorflow::Tensor.
TEST(TfTensorView, Str) {
  using ::tensorflow::tstring;
  using ::testing::ElementsAre;

  // Build a 3x2 tensor whose six entries all start out as empty strings.
  ::tensorflow::TensorProto proto;
  ASSERT_TRUE(TextFormat::ParseFromString(
      R"pb(
        dtype: DT_STRING
        tensor_shape {
          dim:
          [ { size: 3 }
            , { size: 2 }]
        }
        string_val: [ "", "", "", "", "", "" ]
      )pb",
      &proto));
  ::tensorflow::Tensor tf_tensor;
  ASSERT_TRUE(tf_tensor.FromProto(proto));

  // Use the mutable view in place, through a reference into the result.
  auto view_or = TensorView::New(&tf_tensor);
  ASSERT_TRUE(view_or.ok()) << view_or.status();
  auto& view = view_or.value();

  // Overwrite all six elements, one by one, in flat-index order.
  auto strings = view.Data<tstring>();
  strings[0] = "a";
  strings[1] = "bc";
  strings[2] = "def";
  strings[3] = "g";
  strings[4] = "hi";
  strings[5] = "";

  // Reading back through the same mutable view yields what was written.
  EXPECT_THAT(view.Data<tstring>(),
              ElementsAre("a", "bc", "def", "g", "hi", ""));

  // A view created from a const tensor must expose the same contents.
  const auto& readonly_tensor = tf_tensor;
  const auto readonly_view_or = TensorView::New(&readonly_tensor);
  ASSERT_TRUE(readonly_view_or.ok()) << readonly_view_or.status();
  const auto& readonly_view = readonly_view_or.value();

  EXPECT_THAT(readonly_view.Data<tstring>(),
              ElementsAre("a", "bc", "def", "g", "hi", ""));

  // The raw literal starts with a newline purely for readability; that first
  // character is dropped before comparing.
  const char expectation[] = R"(
[["a" "bc"]
 ["def" "g"]
 ["hi" ""]])";
  absl::string_view expected_summary(expectation);
  expected_summary.remove_prefix(1);

  EXPECT_THAT(tf_tensor.SummarizeValue(10, true),
              ::testing::Eq(expected_summary));
}
214 
215 }  // namespace
216 }  // namespace shim
217 }  // namespace tflite
218