1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
18
19 #include <stddef.h>
20
21 #include <cstdint>
22 #include <memory>
23 #include <optional>
24 #include <ostream>
25 #include <string>
26 #include <tuple>
27 #include <utility>
28 #include <vector>
29
30 #include <gmock/gmock.h>
31 #include <gtest/gtest.h>
32 #include "absl/status/status.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/lite/interpreter.h"
35 #include "tensorflow/lite/model.h"
36 #include "tensorflow/lite/string_type.h"
37
38 namespace tflite {
39
// These two functions implement usability printing for TfLiteTensor dimensions
// and coordinates. By default dimensions are interpreted depending on their
// count: 1: Linear, 2: HW, 3: HWC, 4: BHWC. If there are more than 4
// dimensions, std::nullopt is returned.
//
// ShapeToString formats the dimensions stored in `shape`.
// CoordinateToString formats the coordinate of the element at flat index
// `linear` within a tensor whose dimensions are `shape` (TensorEqMatcher passes
// the element's index in the flattened data buffer).
std::optional<std::string> ShapeToString(TfLiteIntArray* shape);
std::optional<std::string> CoordinateToString(TfLiteIntArray* shape,
                                              int linear);
47
48 template <typename TupleMatcher>
49 class TensorEqMatcher {
50 public:
TensorEqMatcher(const TupleMatcher & tuple_matcher,const TfLiteTensor & rhs)51 TensorEqMatcher(const TupleMatcher& tuple_matcher, const TfLiteTensor& rhs)
52 : tuple_matcher_(tuple_matcher), rhs_(rhs) {}
53
54 // Make TensorEqMatcher movable only (The copy operations are implicitly
55 // deleted).
56 TensorEqMatcher(TensorEqMatcher&& other) = default;
57 TensorEqMatcher& operator=(TensorEqMatcher&& other) = default;
58
59 template <typename T>
60 operator testing::Matcher<T>() const { // NOLINT
61 return testing::Matcher<T>(new Impl(tuple_matcher_, rhs_));
62 }
63
64 class Impl : public testing::MatcherInterface<TfLiteTensor> {
65 public:
66 typedef ::std::tuple<float, float> InnerMatcherArg;
67
Impl(const TupleMatcher & tuple_matcher,const TfLiteTensor & rhs)68 Impl(const TupleMatcher& tuple_matcher, const TfLiteTensor& rhs)
69 : mono_tuple_matcher_(
70 testing::SafeMatcherCast<InnerMatcherArg>(tuple_matcher)),
71 rhs_(rhs) {}
72
73 // Make Impl movable only (The copy operations are implicitly deleted).
74 Impl(Impl&& other) = default;
75 Impl& operator=(Impl&& other) = default;
76
77 // Define what gtest framework will print for the Expected field.
DescribeTo(std::ostream * os)78 void DescribeTo(std::ostream* os) const override {
79 std::string shape;
80 std::optional<std::string> result = ShapeToString(rhs_.dims);
81 if (result.has_value()) {
82 shape = std::move(result.value());
83 } else {
84 shape = "[error: unsupported number of dimensions]";
85 }
86 *os << "tensor which has the shape of " << shape
87 << ", where each value and its corresponding expected value ";
88 mono_tuple_matcher_.DescribeTo(os);
89 }
90
MatchAndExplain(TfLiteTensor lhs,testing::MatchResultListener * listener)91 bool MatchAndExplain(
92 TfLiteTensor lhs,
93 testing::MatchResultListener* listener) const override {
94 // 1. Check that TfLiteTensor data type is supported.
95 // Support for other data types will be added on demand.
96 if (lhs.type != kTfLiteFloat32 || rhs_.type != kTfLiteFloat32) {
97 *listener << "which data type is not float32, which is not currently "
98 "supported.";
99 return false;
100 }
101
102 // 2. Check that dimensions' sizes match. Otherwise, we are not able to
103 // compare tensors.
104 if (lhs.dims->size != rhs_.dims->size) {
105 *listener << "which is different from the expected shape of size "
106 << rhs_.dims->size;
107 return false;
108 }
109 // 3. Check that dimensions' values are equal as well. We are not able to
110 // compare tensors of different shapes, even if the total elements count
111 // matches.
112 bool dims_are_equal = true;
113 for (int i = 0; i < lhs.dims->size; i++) {
114 dims_are_equal &= lhs.dims->data[i] == rhs_.dims->data[i];
115 }
116 if (!dims_are_equal) {
117 std::string shape;
118 std::optional<std::string> result = ShapeToString(rhs_.dims);
119 if (result.has_value()) {
120 shape = std::move(result.value());
121 } else {
122 shape = "[error: unsupported number of dimensions]";
123 }
124 *listener << "which is different from the expected shape " << shape;
125 return false;
126 }
127
128 // 4. Proceed to data comparison. Iterate through elements as they lay
129 // flat. If some pair of elements don't match, deduct the coordinate
130 // basing on the dimensions, then return.
131 absl::Span<float> lhs_span(lhs.data.f, lhs.bytes / sizeof(float));
132 absl::Span<float> rhs_span(rhs_.data.f, rhs_.bytes / sizeof(float));
133
134 auto left = lhs_span.begin();
135 auto right = rhs_span.begin();
136 for (size_t i = 0; i != lhs_span.size(); ++i, ++left, ++right) {
137 if (listener->IsInterested()) {
138 testing::StringMatchResultListener inner_listener;
139 if (!mono_tuple_matcher_.MatchAndExplain({*left, *right},
140 &inner_listener)) {
141 *listener << "where the value pair (";
142 testing::internal::UniversalPrint(*left, listener->stream());
143 *listener << ", ";
144 testing::internal::UniversalPrint(*right, listener->stream());
145 std::string coordinate;
146 std::optional<std::string> result = CoordinateToString(lhs.dims, i);
147 if (result.has_value()) {
148 coordinate = std::move(result.value());
149 } else {
150 coordinate = "[error: unsupported number of dimensions]";
151 }
152 *listener << ") with coordinate " << coordinate << " don't match";
153 testing::internal::PrintIfNotEmpty(inner_listener.str(),
154 listener->stream());
155 return false;
156 }
157 } else {
158 if (!mono_tuple_matcher_.Matches({*left, *right})) return false;
159 }
160 }
161
162 return true;
163 }
164
165 private:
166 const testing::Matcher<InnerMatcherArg> mono_tuple_matcher_;
167 const TfLiteTensor rhs_;
168 };
169
170 private:
171 const TupleMatcher tuple_matcher_;
172 const TfLiteTensor rhs_;
173 };
174
// Builds an interpreter for a model and allocates its tensors.
// NOTE(review): `interpreter` appears to be an out-parameter; confirm in the
// implementation whether it is left untouched when a non-OK status is returned.
absl::Status BuildInterpreter(const Model* model,
                              std::unique_ptr<Interpreter>* interpreter);

// Allocates tensors for the given interpreter.
absl::Status AllocateTensors(std::unique_ptr<Interpreter>* interpreter);

// Modifies the interpreter's graph with the given delegate.
absl::Status ModifyGraphWithDelegate(std::unique_ptr<Interpreter>* interpreter,
                                     TfLiteDelegate* delegate);

// Initializes the interpreter's inputs with consecutive values from a fixed
// range. NOTE(review): the range is presumably bounded by `left` and `right`;
// confirm the bounds and their inclusivity in the implementation.
void InitializeInputs(int left, int right,
                      std::unique_ptr<Interpreter>* interpreter);

// Invokes a prebuilt interpreter.
absl::Status Invoke(std::unique_ptr<Interpreter>* interpreter);
192
// Usability structure used to pass parameter data to parameterized tests.
struct TestParams {
  // Name given to the generated gtest.
  std::string name;

  // Raw bytes of the TFLite model associated with this test name (presumably
  // a serialized model buffer — confirm with the code that fills it in).
  std::vector<uint8_t> model;
};

// Defines how a TestParams value is printed to the command line if something
// fails during testing.
std::ostream& operator<<(std::ostream& os, const TestParams& param);
206
207 } // namespace tflite
208
// The gtest framework uses this operator to describe a TfLiteTensor when an
// expectation fails. TfLiteTensor is defined in the global namespace, so this
// streaming operator is declared there as well (outside namespace tflite) so
// that argument-dependent lookup can find it.
std::ostream& operator<<(std::ostream& os, const TfLiteTensor& tensor);
213
214 // Defines a matcher to compare two TfLiteTensors pointwise using the given
215 // tuple matcher for comparing their values.
216 template <typename TupleMatcherT>
TensorEq(const TupleMatcherT & matcher,const TfLiteTensor & rhs)217 inline tflite::TensorEqMatcher<TupleMatcherT> TensorEq(
218 const TupleMatcherT& matcher, const TfLiteTensor& rhs) {
219 return tflite::TensorEqMatcher<TupleMatcherT>(matcher, rhs);
220 }
221
222 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
223