xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/squeeze_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <initializer_list>
18 #include <string>
19 #include <vector>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
24 #include "tensorflow/lite/kernels/test_util.h"
25 #include "tensorflow/lite/schema/schema_generated.h"
26 
27 namespace tflite {
28 namespace {
29 
30 using ::testing::ElementsAreArray;
31 using ::testing::IsEmpty;
32 
33 class BaseSqueezeOpModel : public SingleOpModel {
34  public:
BaseSqueezeOpModel(const TensorData & input,const TensorData & output,std::initializer_list<int> axis)35   BaseSqueezeOpModel(const TensorData& input, const TensorData& output,
36                      std::initializer_list<int> axis) {
37     input_ = AddInput(input);
38     output_ = AddOutput(output);
39     SetBuiltinOp(
40         BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
41         CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis))
42             .Union());
43     BuildInterpreter({GetShape(input_)});
44   }
45 
input()46   int input() { return input_; }
47 
48  protected:
49   int input_;
50   int output_;
51 };
52 
53 template <typename T>
54 class SqueezeOpModel : public BaseSqueezeOpModel {
55  public:
56   using BaseSqueezeOpModel::BaseSqueezeOpModel;
57 
SetInput(std::initializer_list<T> data)58   void SetInput(std::initializer_list<T> data) { PopulateTensor(input_, data); }
59 
SetStringInput(std::initializer_list<string> data)60   void SetStringInput(std::initializer_list<string> data) {
61     PopulateStringTensor(input_, data);
62   }
63 
GetOutput()64   std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
GetStringOutput()65   std::vector<string> GetStringOutput() {
66     return ExtractVector<string>(output_);
67   }
GetOutputShape()68   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
69 };
70 
71 template <typename T>
72 class SqueezeOpTest : public ::testing::Test {};
73 
74 using DataTypes = ::testing::Types<float, int8_t, int16_t, int32_t>;
75 TYPED_TEST_SUITE(SqueezeOpTest, DataTypes);
76 
// Empty axis list on a {1, 24, 1} input: expects a {24} output that holds the
// same 24 values in the same order.
TYPED_TEST(SqueezeOpTest, SqueezeAll) {
  const TensorData input_spec = {GetTensorType<TypeParam>(), {1, 24, 1}};
  const TensorData output_spec = {GetTensorType<TypeParam>(), {24}};
  SqueezeOpModel<TypeParam> model(input_spec, output_spec, /*axis=*/{});

  model.SetInput({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_values = {1,  2,  3,  4,  5,  6,  7,  8,
                                            9,  10, 11, 12, 13, 14, 15, 16,
                                            17, 18, 19, 20, 21, 22, 23, 24};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({24}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_values));
}
91 
// Axis list {2} on a {1, 24, 1} input: expects a {1, 24} output with the data
// unchanged.
// NOTE(review): the output tensor is declared as {24} while the assertion
// expects {1, 24}, so the shape check appears to rely on the op assigning the
// output shape itself - confirm against the kernel before "fixing" the spec.
TYPED_TEST(SqueezeOpTest, SqueezeSelectedAxis) {
  const TensorData input_spec = {GetTensorType<TypeParam>(), {1, 24, 1}};
  const TensorData output_spec = {GetTensorType<TypeParam>(), {24}};
  SqueezeOpModel<TypeParam> model(input_spec, output_spec, /*axis=*/{2});

  model.SetInput({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_values = {1,  2,  3,  4,  5,  6,  7,  8,
                                            9,  10, 11, 12, 13, 14, 15, 16,
                                            17, 18, 19, 20, 21, 22, 23, 24};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 24}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_values));
}
106 
// Axis list {-1, 0} on a {1, 24, 1} input: expects a {24} output with the data
// unchanged. Presumably -1 addresses the last dimension - verify against the
// kernel's axis handling.
TYPED_TEST(SqueezeOpTest, SqueezeNegativeAxis) {
  const TensorData input_spec = {GetTensorType<TypeParam>(), {1, 24, 1}};
  const TensorData output_spec = {GetTensorType<TypeParam>(), {24}};
  SqueezeOpModel<TypeParam> model(input_spec, output_spec, /*axis=*/{-1, 0});

  model.SetInput({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<int> expected_values = {1,  2,  3,  4,  5,  6,  7,  8,
                                            9,  10, 11, 12, 13, 14, 15, 16,
                                            17, 18, 19, 20, 21, 22, 23, 24};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({24}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_values));
}
121 
// Empty axis list on an input whose seven dimensions are all 1: expects an
// empty output shape and the single value 3 in the output.
// NOTE(review): the output tensor is declared as {1} although the assertion
// expects an empty shape - the check appears to rely on the op setting the
// output shape; confirm against the kernel.
TYPED_TEST(SqueezeOpTest, SqueezeAllDims) {
  const TensorData input_spec = {GetTensorType<TypeParam>(),
                                 {1, 1, 1, 1, 1, 1, 1}};
  const TensorData output_spec = {GetTensorType<TypeParam>(), {1}};
  SqueezeOpModel<TypeParam> model(input_spec, output_spec, /*axis=*/{});

  model.SetInput({3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), IsEmpty());
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({3}));
}
132 
// String tensor, empty axis list, {1, 2, 1} input: expects a {2} output
// holding "a" and "b" in order.
TEST(SqueezeOpTest, SqueezeAllString) {
  const TensorData input_spec = {GetTensorType<std::string>(), {1, 2, 1}};
  const TensorData output_spec = {GetTensorType<std::string>(), {2}};
  SqueezeOpModel<std::string> model(input_spec, output_spec, /*axis=*/{});

  model.SetStringInput({"a", "b"});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetStringOutput(), ElementsAreArray({"a", "b"}));
}
142 
// String tensor, axis list {-1}, {1, 2, 1} input: expects a {1, 2} output
// holding "a" and "b" in order.
// NOTE(review): the output tensor is declared as {24}, which matches neither
// the input element count nor the asserted {1, 2}; the shape check appears to
// rely on the op assigning the output shape - confirm against the kernel.
TEST(SqueezeOpTest, SqueezeNegativeAxisString) {
  const TensorData input_spec = {GetTensorType<std::string>(), {1, 2, 1}};
  const TensorData output_spec = {GetTensorType<std::string>(), {24}};
  SqueezeOpModel<std::string> model(input_spec, output_spec, /*axis=*/{-1});

  model.SetStringInput({"a", "b"});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(model.GetStringOutput(), ElementsAreArray({"a", "b"}));
}
152 
// String tensor, empty axis list, input whose seven dimensions are all 1:
// expects an empty output shape and the single element "a".
// NOTE(review): the output tensor is declared as {1} although the assertion
// expects an empty shape - the check appears to rely on the op setting the
// output shape; confirm against the kernel.
TEST(SqueezeOpTest, SqueezeAllDimsString) {
  const TensorData input_spec = {GetTensorType<std::string>(),
                                 {1, 1, 1, 1, 1, 1, 1}};
  const TensorData output_spec = {GetTensorType<std::string>(), {1}};
  SqueezeOpModel<std::string> model(input_spec, output_spec, /*axis=*/{});

  model.SetStringInput({"a"});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), IsEmpty());
  EXPECT_THAT(model.GetStringOutput(), ElementsAreArray({"a"}));
}
163 
164 }  // namespace
165 }  // namespace tflite
166