1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include <initializer_list>
18 #include <string>
19 #include <vector>
20
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
24 #include "tensorflow/lite/kernels/test_util.h"
25 #include "tensorflow/lite/schema/schema_generated.h"
26
27 namespace tflite {
28 namespace {
29
30 using ::testing::ElementsAreArray;
31 using ::testing::IsEmpty;
32
33 class BaseSqueezeOpModel : public SingleOpModel {
34 public:
BaseSqueezeOpModel(const TensorData & input,const TensorData & output,std::initializer_list<int> axis)35 BaseSqueezeOpModel(const TensorData& input, const TensorData& output,
36 std::initializer_list<int> axis) {
37 input_ = AddInput(input);
38 output_ = AddOutput(output);
39 SetBuiltinOp(
40 BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
41 CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis))
42 .Union());
43 BuildInterpreter({GetShape(input_)});
44 }
45
input()46 int input() { return input_; }
47
48 protected:
49 int input_;
50 int output_;
51 };
52
53 template <typename T>
54 class SqueezeOpModel : public BaseSqueezeOpModel {
55 public:
56 using BaseSqueezeOpModel::BaseSqueezeOpModel;
57
SetInput(std::initializer_list<T> data)58 void SetInput(std::initializer_list<T> data) { PopulateTensor(input_, data); }
59
SetStringInput(std::initializer_list<string> data)60 void SetStringInput(std::initializer_list<string> data) {
61 PopulateStringTensor(input_, data);
62 }
63
GetOutput()64 std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
GetStringOutput()65 std::vector<string> GetStringOutput() {
66 return ExtractVector<string>(output_);
67 }
GetOutputShape()68 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
69 };
70
71 template <typename T>
72 class SqueezeOpTest : public ::testing::Test {};
73
74 using DataTypes = ::testing::Types<float, int8_t, int16_t, int32_t>;
75 TYPED_TEST_SUITE(SqueezeOpTest, DataTypes);
76
// Squeeze with an empty axis list: every size-1 dimension of the {1, 24, 1}
// input is dropped, leaving a rank-1 tensor whose values are unchanged.
TYPED_TEST(SqueezeOpTest, SqueezeAll) {
  const auto element_type = GetTensorType<TypeParam>();
  SqueezeOpModel<TypeParam> model({element_type, {1, 24, 1}},
                                  {element_type, {24}}, /*axis=*/{});
  model.SetInput({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({24}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({1,  2,  3,  4,  5,  6,  7,  8,
                                9,  10, 11, 12, 13, 14, 15, 16,
                                17, 18, 19, 20, 21, 22, 23, 24}));
}
91
// Squeezing only axis 2 removes the trailing size-1 dimension and keeps the
// leading one, so a {1, 24, 1} input becomes {1, 24}.
TYPED_TEST(SqueezeOpTest, SqueezeSelectedAxis) {
  std::initializer_list<TypeParam> data = {1,  2,  3,  4,  5,  6,  7,  8,
                                           9,  10, 11, 12, 13, 14, 15, 16,
                                           17, 18, 19, 20, 21, 22, 23, 24};
  // Declare the output with the shape the op is expected to produce, so the
  // model description agrees with the shape assertion below.
  SqueezeOpModel<TypeParam> m({GetTensorType<TypeParam>(), {1, 24, 1}},
                              {GetTensorType<TypeParam>(), {1, 24}}, {2});
  m.SetInput(data);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24}));
  EXPECT_THAT(
      m.GetOutput(),
      ElementsAreArray({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                        13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}));
}
106
// Negative axes count from the end: with {-1, 0}, both the last and the first
// size-1 dimensions of {1, 24, 1} are removed, leaving {24}.
TYPED_TEST(SqueezeOpTest, SqueezeNegativeAxis) {
  const std::initializer_list<TypeParam> values = {
      1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
      13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
  const auto element_type = GetTensorType<TypeParam>();
  SqueezeOpModel<TypeParam> model({element_type, {1, 24, 1}},
                                  {element_type, {24}}, /*axis=*/{-1, 0});
  model.SetInput(values);

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({24}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({1,  2,  3,  4,  5,  6,  7,  8,
                                9,  10, 11, 12, 13, 14, 15, 16,
                                17, 18, 19, 20, 21, 22, 23, 24}));
}
121
// A rank-7 input made entirely of size-1 dimensions squeezes down to a
// scalar: the output shape is empty and the single value is preserved.
TYPED_TEST(SqueezeOpTest, SqueezeAllDims) {
  const auto element_type = GetTensorType<TypeParam>();
  SqueezeOpModel<TypeParam> model({element_type, {1, 1, 1, 1, 1, 1, 1}},
                                  {element_type, {1}}, /*axis=*/{});
  model.SetInput({3});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), IsEmpty());
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({3}));
}
132
// String variant of SqueezeAll: {1, 2, 1} with no explicit axes becomes {2}
// and the string payload is carried through unchanged.
TEST(SqueezeOpTest, SqueezeAllString) {
  const auto string_type = GetTensorType<std::string>();
  SqueezeOpModel<std::string> model({string_type, {1, 2, 1}},
                                    {string_type, {2}}, /*axis=*/{});
  model.SetStringInput({"a", "b"});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetStringOutput(), ElementsAreArray({"a", "b"}));
}
142
// String variant of negative-axis squeezing: axis -1 removes only the trailing
// size-1 dimension, so {1, 2, 1} becomes {1, 2}.
TEST(SqueezeOpTest, SqueezeNegativeAxisString) {
  std::initializer_list<std::string> data = {"a", "b"};
  // Declare the output with the shape the op is expected to produce (two
  // strings laid out as {1, 2}), matching the assertions below.
  SqueezeOpModel<std::string> m({GetTensorType<std::string>(), {1, 2, 1}},
                                {GetTensorType<std::string>(), {1, 2}}, {-1});
  m.SetStringInput(data);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a", "b"}));
}
152
// String variant of SqueezeAllDims: a rank-7 all-ones input collapses to a
// scalar holding the single string element.
TEST(SqueezeOpTest, SqueezeAllDimsString) {
  const std::initializer_list<std::string> single_value = {"a"};
  const auto string_type = GetTensorType<std::string>();
  SqueezeOpModel<std::string> model({string_type, {1, 1, 1, 1, 1, 1, 1}},
                                    {string_type, {1}}, /*axis=*/{});
  model.SetStringInput(single_value);

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), IsEmpty());
  EXPECT_THAT(model.GetStringOutput(), ElementsAreArray({"a"}));
}
163
164 } // namespace
165 } // namespace tflite
166