xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/cast_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <complex>
18 #include <vector>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
23 #include "tensorflow/lite/kernels/test_util.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 
26 namespace tflite {
27 namespace {
28 
29 using ::testing::ElementsAreArray;
30 
31 class CastOpModel : public SingleOpModel {
32  public:
CastOpModel(const TensorData & input,const TensorData & output)33   CastOpModel(const TensorData& input, const TensorData& output) {
34     input_ = AddInput(input);
35     output_ = AddOutput(output);
36     SetBuiltinOp(BuiltinOperator_CAST, BuiltinOptions_CastOptions,
37                  CreateCastOptions(builder_).Union());
38     BuildInterpreter({GetShape(input_)});
39   }
40 
input() const41   int input() const { return input_; }
output() const42   int output() const { return output_; }
43 
44  protected:
45   int input_;
46   int output_;
47 };
48 
// INT16 -> FLOAT32 on a 2x3 tensor: every integer must come out as the same
// value in floating point.
TEST(CastOpModel, CastInt16ToFloat) {
  CastOpModel m(/*input=*/{TensorType_INT16, {2, 3}},
                /*output=*/{TensorType_FLOAT32, {2, 3}});
  m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {100.f, 200.f, 300.f,
                                       400.f, 500.f, 600.f};
  const auto actual = m.ExtractVector<float>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
56 
// INT16 -> INT32 on a 2x3 tensor: widening must keep every value unchanged.
TEST(CastOpModel, CastInt16ToInt32) {
  CastOpModel m(/*input=*/{TensorType_INT16, {2, 3}},
                /*output=*/{TensorType_INT32, {2, 3}});
  m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 200, 300, 400, 500, 600};
  const auto actual = m.ExtractVector<int32_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
64 
// INT32 -> FLOAT32 on a 2x3 tensor: each integer is expected back as the
// equal floating-point value.
TEST(CastOpModel, CastInt32ToFloat) {
  CastOpModel m(/*input=*/{TensorType_INT32, {2, 3}},
                /*output=*/{TensorType_FLOAT32, {2, 3}});
  m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {100.f, 200.f, 300.f,
                                       400.f, 500.f, 600.f};
  const auto actual = m.ExtractVector<float>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
72 
// FLOAT32 -> INT32 on a 3x2 tensor. The expected values drop the fractional
// part of each input (0.4 -> 0, 0.999 -> 0, 1.1 -> 1).
TEST(CastOpModel, CastFloatToInt32) {
  CastOpModel m(/*input=*/{TensorType_FLOAT32, {3, 2}},
                /*output=*/{TensorType_INT32, {3, 2}});
  m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 20, 3, 0, 0, 1};
  const auto actual = m.ExtractVector<int32_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
80 
// FLOAT32 -> INT16 on a 3x2 tensor. The expected values drop the fractional
// part of each input (0.4 -> 0, 0.999 -> 0, 1.1 -> 1).
TEST(CastOpModel, CastFloatToInt16) {
  CastOpModel m(/*input=*/{TensorType_FLOAT32, {3, 2}},
                /*output=*/{TensorType_INT16, {3, 2}});
  m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 20, 3, 0, 0, 1};
  const auto actual = m.ExtractVector<int16_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
88 
// INT64 -> FLOAT32 on a 2x3 tensor: each integer is expected back as the
// equal floating-point value.
TEST(CastOpModel, CastInt64ToFloat) {
  CastOpModel m(/*input=*/{TensorType_INT64, {2, 3}},
                /*output=*/{TensorType_FLOAT32, {2, 3}});
  m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {100.f, 200.f, 300.f,
                                       400.f, 500.f, 600.f};
  const auto actual = m.ExtractVector<float>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
96 
// FLOAT32 -> INT64 on a 3x2 tensor. The expected values drop the fractional
// part of each input (0.4 -> 0, 0.999 -> 0, 1.1 -> 1).
TEST(CastOpModel, CastFloatToInt64) {
  CastOpModel m(/*input=*/{TensorType_FLOAT32, {3, 2}},
                /*output=*/{TensorType_INT64, {3, 2}});
  m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 20, 3, 0, 0, 1};
  const auto actual = m.ExtractVector<int64_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
104 
// FLOAT32 -> BOOL on a 3x2 tensor. Only the exact 0.f input is expected to
// become false; negative and fractional inputs are expected to become true.
TEST(CastOpModel, CastFloatToBool) {
  CastOpModel m(/*input=*/{TensorType_FLOAT32, {3, 2}},
                /*output=*/{TensorType_BOOL, {3, 2}});
  m.PopulateTensor<float>(m.input(), {100.f, -1.0f, 0.f, 0.4f, 0.999f, 1.1f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<bool> expected = {true, true, false, true, true, true};
  const auto actual = m.ExtractVector<bool>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
112 
// BOOL -> FLOAT32 on a 3x2 tensor: true is expected as 1.0 and false as 0.0.
TEST(CastOpModel, CastBoolToFloat) {
  CastOpModel m(/*input=*/{TensorType_BOOL, {3, 2}},
                /*output=*/{TensorType_FLOAT32, {3, 2}});
  m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, true});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f};
  const auto actual = m.ExtractVector<float>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
120 
// FLOAT32 -> UINT8 on a 3x2 tensor. The expected values drop the fractional
// part of each input (0.4 -> 0, 1.999 -> 1, 1.1 -> 1).
TEST(CastOpModel, CastFloatToUInt8) {
  CastOpModel m(/*input=*/{TensorType_FLOAT32, {3, 2}},
                /*output=*/{TensorType_UINT8, {3, 2}});
  m.PopulateTensor<float>(m.input(), {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 1, 0, 0, 1, 1};
  const auto actual = m.ExtractVector<uint8_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
128 
// UINT8 -> FLOAT32 on a 3x2 tensor: each integer is expected back as the
// equal floating-point value.
TEST(CastOpModel, CastUInt8ToFloat) {
  CastOpModel m(/*input=*/{TensorType_UINT8, {3, 2}},
                /*output=*/{TensorType_FLOAT32, {3, 2}});
  m.PopulateTensor<uint8_t>(m.input(), {123, 0, 1, 2, 3, 4});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {123.f, 0.f, 1.f, 2.f, 3.f, 4.f};
  const auto actual = m.ExtractVector<float>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
136 
// FLOAT32 -> UINT16 on a 3x2 tensor. The expected values drop the fractional
// part of each input (0.4 -> 0, 1.999 -> 1, 1.1 -> 1).
TEST(CastOpModel, CastFloatToUInt16) {
  CastOpModel m(/*input=*/{TensorType_FLOAT32, {3, 2}},
                /*output=*/{TensorType_UINT16, {3, 2}});
  m.PopulateTensor<float>(m.input(), {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 1, 0, 0, 1, 1};
  const auto actual = m.ExtractVector<uint16_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
144 
// UINT16 -> FLOAT32 on a 3x2 tensor: each integer is expected back as the
// equal floating-point value.
TEST(CastOpModel, CastUInt16ToFloat) {
  CastOpModel m(/*input=*/{TensorType_UINT16, {3, 2}},
                /*output=*/{TensorType_FLOAT32, {3, 2}});
  m.PopulateTensor<uint16_t>(m.input(), {123, 0, 1, 2, 3, 4});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {123.f, 0.f, 1.f, 2.f, 3.f, 4.f};
  const auto actual = m.ExtractVector<float>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
152 
// INT32 -> UINT8 on a 3x2 tensor. All inputs lie in [1, 255], so the values
// are expected to come out unchanged.
TEST(CastOpModel, CastInt32ToUInt8) {
  CastOpModel m(/*input=*/{TensorType_INT32, {3, 2}},
                /*output=*/{TensorType_UINT8, {3, 2}});
  m.PopulateTensor<int32_t>(m.input(), {100, 1, 200, 2, 255, 3});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 1, 200, 2, 255, 3};
  const auto actual = m.ExtractVector<uint8_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
160 
// UINT8 -> INT32 on a 3x2 tensor: widening must keep every value unchanged,
// including the 200 and 255 inputs.
TEST(CastOpModel, CastUInt8ToInt32) {
  CastOpModel m(/*input=*/{TensorType_UINT8, {3, 2}},
                /*output=*/{TensorType_INT32, {3, 2}});
  m.PopulateTensor<uint8_t>(m.input(), {100, 1, 200, 2, 255, 3});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 1, 200, 2, 255, 3};
  const auto actual = m.ExtractVector<int32_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
168 
// COMPLEX64 -> FLOAT32 on a 2x3 tensor: the expected output holds only the
// real part of each input; the imaginary parts (11..16) do not appear.
TEST(CastOpModel, CastComplex64ToFloat) {
  using Complex64 = std::complex<float>;

  CastOpModel m(/*input=*/{TensorType_COMPLEX64, {2, 3}},
                /*output=*/{TensorType_FLOAT32, {2, 3}});
  m.PopulateTensor<Complex64>(
      m.input(), {Complex64(1.0f, 11.0f), Complex64(2.0f, 12.0f),
                  Complex64(3.0f, 13.0f), Complex64(4.0f, 14.0f),
                  Complex64(5.0f, 15.0f), Complex64(6.0f, 16.0f)});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  const auto actual = m.ExtractVector<float>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
180 
// FLOAT32 -> COMPLEX64 on a 2x3 tensor: each input is expected as the real
// part of the result, paired with a zero imaginary part.
TEST(CastOpModel, CastFloatToComplex64) {
  using Complex64 = std::complex<float>;

  CastOpModel m(/*input=*/{TensorType_FLOAT32, {2, 3}},
                /*output=*/{TensorType_COMPLEX64, {2, 3}});
  m.PopulateTensor<float>(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<Complex64> expected = {
      Complex64(1.0f, 0.0f), Complex64(2.0f, 0.0f), Complex64(3.0f, 0.0f),
      Complex64(4.0f, 0.0f), Complex64(5.0f, 0.0f), Complex64(6.0f, 0.0f)};
  const auto actual = m.ExtractVector<Complex64>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
192 
// COMPLEX64 -> INT32 on a 2x3 tensor: the expected output holds the real
// part of each input as an integer; the imaginary parts do not appear.
TEST(CastOpModel, CastComplex64ToInt) {
  using Complex64 = std::complex<float>;

  CastOpModel m(/*input=*/{TensorType_COMPLEX64, {2, 3}},
                /*output=*/{TensorType_INT32, {2, 3}});
  m.PopulateTensor<Complex64>(
      m.input(), {Complex64(1.0f, 11.0f), Complex64(2.0f, 12.0f),
                  Complex64(3.0f, 13.0f), Complex64(4.0f, 14.0f),
                  Complex64(5.0f, 15.0f), Complex64(6.0f, 16.0f)});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {1, 2, 3, 4, 5, 6};
  const auto actual = m.ExtractVector<int>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
204 
// INT32 -> COMPLEX64 on a 2x3 tensor: each input is expected as the real
// part of the result, paired with a zero imaginary part.
TEST(CastOpModel, CastIntToComplex64) {
  using Complex64 = std::complex<float>;

  CastOpModel m(/*input=*/{TensorType_INT32, {2, 3}},
                /*output=*/{TensorType_COMPLEX64, {2, 3}});
  m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<Complex64> expected = {
      Complex64(1.0f, 0.0f), Complex64(2.0f, 0.0f), Complex64(3.0f, 0.0f),
      Complex64(4.0f, 0.0f), Complex64(5.0f, 0.0f), Complex64(6.0f, 0.0f)};
  const auto actual = m.ExtractVector<Complex64>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
216 
// COMPLEX64 -> COMPLEX64 on a 2x3 tensor: the output is expected to equal
// the input, real and imaginary parts alike.
TEST(CastOpModel, CastComplex64ToComplex64) {
  using Complex64 = std::complex<float>;

  CastOpModel m(/*input=*/{TensorType_COMPLEX64, {2, 3}},
                /*output=*/{TensorType_COMPLEX64, {2, 3}});
  m.PopulateTensor<Complex64>(
      m.input(), {Complex64(1.0f, 11.0f), Complex64(2.0f, 12.0f),
                  Complex64(3.0f, 13.0f), Complex64(4.0f, 14.0f),
                  Complex64(5.0f, 15.0f), Complex64(6.0f, 16.0f)});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<Complex64> expected = {
      Complex64(1.0f, 11.0f), Complex64(2.0f, 12.0f), Complex64(3.0f, 13.0f),
      Complex64(4.0f, 14.0f), Complex64(5.0f, 15.0f), Complex64(6.0f, 16.0f)};
  const auto actual = m.ExtractVector<Complex64>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
233 
// UINT32 -> INT32 on a 2x3 tensor: the small positive inputs used here are
// expected to come out unchanged.
TEST(CastOpModel, CastUInt32ToInt32) {
  CastOpModel m(/*input=*/{TensorType_UINT32, {2, 3}},
                /*output=*/{TensorType_INT32, {2, 3}});
  m.PopulateTensor<uint32_t>(m.input(), {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 200, 300, 400, 500, 600};
  const auto actual = m.ExtractVector<int32_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
241 
// INT32 -> UINT32 on a 2x3 tensor: the non-negative inputs used here are
// expected to come out unchanged.
TEST(CastOpModel, CastInt32ToUInt32) {
  CastOpModel m(/*input=*/{TensorType_INT32, {2, 3}},
                /*output=*/{TensorType_UINT32, {2, 3}});
  m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {100, 200, 300, 400, 500, 600};
  const auto actual = m.ExtractVector<uint32_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
249 
// UINT8 -> INT8 on a 2x3 tensor: the inputs (10..60) are expected to come
// out unchanged.
TEST(CastOpModel, CastUInt8ToInt8) {
  CastOpModel m(/*input=*/{TensorType_UINT8, {2, 3}},
                /*output=*/{TensorType_INT8, {2, 3}});
  m.PopulateTensor<uint8_t>(m.input(), {10, 20, 30, 40, 50, 60});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {10, 20, 30, 40, 50, 60};
  const auto actual = m.ExtractVector<int8_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
257 
// INT8 -> UINT8 on a 2x3 tensor: the non-negative inputs (10..60) are
// expected to come out unchanged.
TEST(CastOpModel, CastInt8ToUInt8) {
  CastOpModel m(/*input=*/{TensorType_INT8, {2, 3}},
                /*output=*/{TensorType_UINT8, {2, 3}});
  m.PopulateTensor<int8_t>(m.input(), {10, 20, 30, 40, 50, 60});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {10, 20, 30, 40, 50, 60};
  const auto actual = m.ExtractVector<uint8_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
265 
// UINT16 -> INT16 on a 2x3 tensor: the inputs (10..60) are expected to come
// out unchanged.
TEST(CastOpModel, CastUInt16ToInt16) {
  CastOpModel m(/*input=*/{TensorType_UINT16, {2, 3}},
                /*output=*/{TensorType_INT16, {2, 3}});
  m.PopulateTensor<uint16_t>(m.input(), {10, 20, 30, 40, 50, 60});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {10, 20, 30, 40, 50, 60};
  const auto actual = m.ExtractVector<int16_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
273 
// INT16 -> UINT16 on a 2x3 tensor: the non-negative inputs (10..60) are
// expected to come out unchanged.
TEST(CastOpModel, CastInt16ToUInt16) {
  CastOpModel m(/*input=*/{TensorType_INT16, {2, 3}},
                /*output=*/{TensorType_UINT16, {2, 3}});
  m.PopulateTensor<int16_t>(m.input(), {10, 20, 30, 40, 50, 60});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected = {10, 20, 30, 40, 50, 60};
  const auto actual = m.ExtractVector<uint16_t>(m.output());
  EXPECT_THAT(actual, ElementsAreArray(expected));
}
281 
282 }  // namespace
283 }  // namespace tflite
284