1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include <complex>
18 #include <vector>
19
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
23 #include "tensorflow/lite/kernels/test_util.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25
26 namespace tflite {
27 namespace {
28
29 using ::testing::ElementsAreArray;
30
31 class CastOpModel : public SingleOpModel {
32 public:
CastOpModel(const TensorData & input,const TensorData & output)33 CastOpModel(const TensorData& input, const TensorData& output) {
34 input_ = AddInput(input);
35 output_ = AddOutput(output);
36 SetBuiltinOp(BuiltinOperator_CAST, BuiltinOptions_CastOptions,
37 CreateCastOptions(builder_).Union());
38 BuildInterpreter({GetShape(input_)});
39 }
40
input() const41 int input() const { return input_; }
output() const42 int output() const { return output_; }
43
44 protected:
45 int input_;
46 int output_;
47 };
48
// INT16 -> FLOAT32 on a {2, 3} tensor: every integer is expected to come back
// as the float with the same value.
TEST(CastOpModel, CastInt16ToFloat) {
  CastOpModel model({TensorType_INT16, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
  model.PopulateTensor<int16_t>(model.input(),
                                {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {100.f, 200.f, 300.f,
                                       400.f, 500.f, 600.f};
  const auto actual = model.ExtractVector<float>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
56
// INT16 -> INT32 on a {2, 3} tensor: the widened output is expected to hold
// the same integer values as the input.
TEST(CastOpModel, CastInt16ToInt32) {
  CastOpModel model({TensorType_INT16, {2, 3}}, {TensorType_INT32, {2, 3}});
  model.PopulateTensor<int16_t>(model.input(),
                                {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 200, 300, 400, 500, 600};
  const auto actual = model.ExtractVector<int32_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
64
// INT32 -> FLOAT32 on a {2, 3} tensor: each integer is expected to map to the
// float of equal value.
TEST(CastOpModel, CastInt32ToFloat) {
  CastOpModel model({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
  model.PopulateTensor<int32_t>(model.input(),
                                {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {100.f, 200.f, 300.f,
                                       400.f, 500.f, 600.f};
  const auto actual = model.ExtractVector<float>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
72
// FLOAT32 -> INT32 on a {3, 2} tensor. The expected values drop the
// fractional part of each input (0.4 -> 0, 0.999 -> 0, 1.1 -> 1).
TEST(CastOpModel, CastFloatToInt32) {
  CastOpModel model({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
  model.PopulateTensor<float>(model.input(),
                              {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 20, 3, 0, 0, 1};
  const auto actual = model.ExtractVector<int32_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
80
// FLOAT32 -> INT16 on a {3, 2} tensor. The expected values drop the
// fractional part of each input (0.4 -> 0, 0.999 -> 0, 1.1 -> 1).
TEST(CastOpModel, CastFloatToInt16) {
  CastOpModel model({TensorType_FLOAT32, {3, 2}}, {TensorType_INT16, {3, 2}});
  model.PopulateTensor<float>(model.input(),
                              {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 20, 3, 0, 0, 1};
  const auto actual = model.ExtractVector<int16_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
88
// INT64 -> FLOAT32 on a {2, 3} tensor: each integer is expected to map to the
// float of equal value.
TEST(CastOpModel, CastInt64ToFloat) {
  CastOpModel model({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
  model.PopulateTensor<int64_t>(model.input(),
                                {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {100.f, 200.f, 300.f,
                                       400.f, 500.f, 600.f};
  const auto actual = model.ExtractVector<float>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
96
// FLOAT32 -> INT64 on a {3, 2} tensor. The expected values drop the
// fractional part of each input (0.4 -> 0, 0.999 -> 0, 1.1 -> 1).
TEST(CastOpModel, CastFloatToInt64) {
  CastOpModel model({TensorType_FLOAT32, {3, 2}}, {TensorType_INT64, {3, 2}});
  model.PopulateTensor<float>(model.input(),
                              {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 20, 3, 0, 0, 1};
  const auto actual = model.ExtractVector<int64_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
104
// FLOAT32 -> BOOL on a {3, 2} tensor. Only the 0.f input is expected to yield
// false; every other input (including the negative one) is expected to be
// true.
TEST(CastOpModel, CastFloatToBool) {
  CastOpModel model({TensorType_FLOAT32, {3, 2}}, {TensorType_BOOL, {3, 2}});
  model.PopulateTensor<float>(model.input(),
                              {100.f, -1.0f, 0.f, 0.4f, 0.999f, 1.1f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const auto actual = model.ExtractVector<bool>(model.output());
  EXPECT_THAT(actual,
              ElementsAreArray({true, true, false, true, true, true}));
}
112
// BOOL -> FLOAT32 on a {3, 2} tensor: true is expected to become 1.0f and
// false is expected to become 0.0f.
TEST(CastOpModel, CastBoolToFloat) {
  CastOpModel model({TensorType_BOOL, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
  model.PopulateTensor<bool>(model.input(),
                             {true, true, false, true, false, true});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f};
  const auto actual = model.ExtractVector<float>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
120
// FLOAT32 -> UINT8 on a {3, 2} tensor. The expected values drop the
// fractional part of each input (0.4 -> 0, 1.999 -> 1, 1.1 -> 1).
TEST(CastOpModel, CastFloatToUInt8) {
  CastOpModel model({TensorType_FLOAT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
  model.PopulateTensor<float>(model.input(),
                              {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 1, 0, 0, 1, 1};
  const auto actual = model.ExtractVector<uint8_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
128
// UINT8 -> FLOAT32 on a {3, 2} tensor: each byte value is expected to map to
// the float of equal value.
TEST(CastOpModel, CastUInt8ToFloat) {
  CastOpModel model({TensorType_UINT8, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
  model.PopulateTensor<uint8_t>(model.input(), {123, 0, 1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {123.f, 0.f, 1.f, 2.f, 3.f, 4.f};
  const auto actual = model.ExtractVector<float>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
136
// FLOAT32 -> UINT16 on a {3, 2} tensor. The expected values drop the
// fractional part of each input (0.4 -> 0, 1.999 -> 1, 1.1 -> 1).
TEST(CastOpModel, CastFloatToUInt16) {
  CastOpModel model({TensorType_FLOAT32, {3, 2}}, {TensorType_UINT16, {3, 2}});
  model.PopulateTensor<float>(model.input(),
                              {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 1, 0, 0, 1, 1};
  const auto actual = model.ExtractVector<uint16_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
144
// UINT16 -> FLOAT32 on a {3, 2} tensor: each integer is expected to map to
// the float of equal value.
TEST(CastOpModel, CastUInt16ToFloat) {
  CastOpModel model({TensorType_UINT16, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
  model.PopulateTensor<uint16_t>(model.input(), {123, 0, 1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {123.f, 0.f, 1.f, 2.f, 3.f, 4.f};
  const auto actual = model.ExtractVector<float>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
152
// INT32 -> UINT8 on a {3, 2} tensor. All inputs lie within 0..255, and the
// output is expected to hold the same values.
TEST(CastOpModel, CastInt32ToUInt8) {
  CastOpModel model({TensorType_INT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
  model.PopulateTensor<int32_t>(model.input(), {100, 1, 200, 2, 255, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 1, 200, 2, 255, 3};
  const auto actual = model.ExtractVector<uint8_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
160
// UINT8 -> INT32 on a {3, 2} tensor: the widened output is expected to hold
// the same values as the input, including 200 and 255.
TEST(CastOpModel, CastUInt8ToInt32) {
  CastOpModel model({TensorType_UINT8, {3, 2}}, {TensorType_INT32, {3, 2}});
  model.PopulateTensor<uint8_t>(model.input(), {100, 1, 200, 2, 255, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 1, 200, 2, 255, 3};
  const auto actual = model.ExtractVector<int32_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
168
// COMPLEX64 -> FLOAT32 on a {2, 3} tensor: the expected output is the real
// part of each input; the imaginary parts (11..16) do not appear in it.
TEST(CastOpModel, CastComplex64ToFloat) {
  using Complex64 = std::complex<float>;

  CastOpModel model({TensorType_COMPLEX64, {2, 3}},
                    {TensorType_FLOAT32, {2, 3}});
  model.PopulateTensor<Complex64>(
      model.input(),
      {Complex64(1.0f, 11.0f), Complex64(2.0f, 12.0f), Complex64(3.0f, 13.0f),
       Complex64(4.0f, 14.0f), Complex64(5.0f, 15.0f),
       Complex64(6.0f, 16.0f)});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  const auto actual = model.ExtractVector<float>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
180
// FLOAT32 -> COMPLEX64 on a {2, 3} tensor: each float is expected to become
// the real part of the output, paired with a zero imaginary part.
TEST(CastOpModel, CastFloatToComplex64) {
  using Complex64 = std::complex<float>;

  CastOpModel model({TensorType_FLOAT32, {2, 3}},
                    {TensorType_COMPLEX64, {2, 3}});
  model.PopulateTensor<float>(model.input(),
                              {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<Complex64> expected = {
      Complex64(1.0f, 0.0f), Complex64(2.0f, 0.0f), Complex64(3.0f, 0.0f),
      Complex64(4.0f, 0.0f), Complex64(5.0f, 0.0f), Complex64(6.0f, 0.0f)};
  const auto actual = model.ExtractVector<Complex64>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
192
// COMPLEX64 -> INT32 on a {2, 3} tensor: the expected output is the real part
// of each input as an integer; the imaginary parts do not appear in it.
TEST(CastOpModel, CastComplex64ToInt) {
  using Complex64 = std::complex<float>;

  CastOpModel model({TensorType_COMPLEX64, {2, 3}},
                    {TensorType_INT32, {2, 3}});
  model.PopulateTensor<Complex64>(
      model.input(),
      {Complex64(1.0f, 11.0f), Complex64(2.0f, 12.0f), Complex64(3.0f, 13.0f),
       Complex64(4.0f, 14.0f), Complex64(5.0f, 15.0f),
       Complex64(6.0f, 16.0f)});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {1, 2, 3, 4, 5, 6};
  const auto actual = model.ExtractVector<int>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
204
// INT32 -> COMPLEX64 on a {2, 3} tensor: each integer is expected to become
// the real part of the output, paired with a zero imaginary part.
TEST(CastOpModel, CastIntToComplex64) {
  using Complex64 = std::complex<float>;

  CastOpModel model({TensorType_INT32, {2, 3}},
                    {TensorType_COMPLEX64, {2, 3}});
  model.PopulateTensor<int>(model.input(), {1, 2, 3, 4, 5, 6});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<Complex64> expected = {
      Complex64(1.0f, 0.0f), Complex64(2.0f, 0.0f), Complex64(3.0f, 0.0f),
      Complex64(4.0f, 0.0f), Complex64(5.0f, 0.0f), Complex64(6.0f, 0.0f)};
  const auto actual = model.ExtractVector<Complex64>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
216
// COMPLEX64 -> COMPLEX64 on a {2, 3} tensor: the output is expected to match
// the input element for element, real and imaginary parts alike.
TEST(CastOpModel, CastComplex64ToComplex64) {
  using Complex64 = std::complex<float>;

  CastOpModel model({TensorType_COMPLEX64, {2, 3}},
                    {TensorType_COMPLEX64, {2, 3}});
  model.PopulateTensor<Complex64>(
      model.input(),
      {Complex64(1.0f, 11.0f), Complex64(2.0f, 12.0f), Complex64(3.0f, 13.0f),
       Complex64(4.0f, 14.0f), Complex64(5.0f, 15.0f),
       Complex64(6.0f, 16.0f)});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<Complex64> expected = {
      Complex64(1.0f, 11.0f), Complex64(2.0f, 12.0f), Complex64(3.0f, 13.0f),
      Complex64(4.0f, 14.0f), Complex64(5.0f, 15.0f), Complex64(6.0f, 16.0f)};
  const auto actual = model.ExtractVector<Complex64>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
233
// UINT32 -> INT32 on a {2, 3} tensor: the output is expected to hold the same
// values as the input (100..600).
TEST(CastOpModel, CastUInt32ToInt32) {
  CastOpModel model({TensorType_UINT32, {2, 3}}, {TensorType_INT32, {2, 3}});
  model.PopulateTensor<uint32_t>(model.input(),
                                 {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 200, 300, 400, 500, 600};
  const auto actual = model.ExtractVector<int32_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
241
// INT32 -> UINT32 on a {2, 3} tensor: the inputs are all non-negative, and
// the output is expected to hold the same values.
TEST(CastOpModel, CastInt32ToUInt32) {
  CastOpModel model({TensorType_INT32, {2, 3}}, {TensorType_UINT32, {2, 3}});
  model.PopulateTensor<int32_t>(model.input(),
                                {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 200, 300, 400, 500, 600};
  const auto actual = model.ExtractVector<uint32_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
249
// UINT8 -> INT8 on a {2, 3} tensor: the output is expected to hold the same
// values as the input (10..60).
TEST(CastOpModel, CastUInt8ToInt8) {
  CastOpModel model({TensorType_UINT8, {2, 3}}, {TensorType_INT8, {2, 3}});
  model.PopulateTensor<uint8_t>(model.input(), {10, 20, 30, 40, 50, 60});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {10, 20, 30, 40, 50, 60};
  const auto actual = model.ExtractVector<int8_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
257
// INT8 -> UINT8 on a {2, 3} tensor: the inputs are all non-negative, and the
// output is expected to hold the same values.
TEST(CastOpModel, CastInt8ToUInt8) {
  CastOpModel model({TensorType_INT8, {2, 3}}, {TensorType_UINT8, {2, 3}});
  model.PopulateTensor<int8_t>(model.input(), {10, 20, 30, 40, 50, 60});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {10, 20, 30, 40, 50, 60};
  const auto actual = model.ExtractVector<uint8_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
265
// UINT16 -> INT16 on a {2, 3} tensor: the output is expected to hold the same
// values as the input (10..60).
TEST(CastOpModel, CastUInt16ToInt16) {
  CastOpModel model({TensorType_UINT16, {2, 3}}, {TensorType_INT16, {2, 3}});
  model.PopulateTensor<uint16_t>(model.input(), {10, 20, 30, 40, 50, 60});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {10, 20, 30, 40, 50, 60};
  const auto actual = model.ExtractVector<int16_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
273
// INT16 -> UINT16 on a {2, 3} tensor: the inputs are all non-negative, and
// the output is expected to hold the same values.
TEST(CastOpModel, CastInt16ToUInt16) {
  CastOpModel model({TensorType_INT16, {2, 3}}, {TensorType_UINT16, {2, 3}});
  model.PopulateTensor<int16_t>(model.input(), {10, 20, 30, 40, 50, 60});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {10, 20, 30, 40, 50, 60};
  const auto actual = model.ExtractVector<uint16_t>(model.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
281
282 } // namespace
283 } // namespace tflite
284