xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
18 
19 #include <stddef.h>
20 
21 #include <cstdint>
22 #include <memory>
23 #include <optional>
24 #include <ostream>
25 #include <string>
26 #include <tuple>
27 #include <utility>
28 #include <vector>
29 
30 #include <gmock/gmock.h>
31 #include <gtest/gtest.h>
32 #include "absl/status/status.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/lite/interpreter.h"
35 #include "tensorflow/lite/model.h"
36 #include "tensorflow/lite/string_type.h"
37 
38 namespace tflite {
39 
40 // These two functions implement usability printing for TfLiteTensor dimensions
41 // and coordinates. By default dimensions are interpreted depending on the size:
42 // 1:Linear, 2:HW, 3: HWC, 4:BHWC. If there are more than 4 dimensions,
43 // std::nullopt will be returned.
44 std::optional<std::string> ShapeToString(TfLiteIntArray* shape);
45 std::optional<std::string> CoordinateToString(TfLiteIntArray* shape,
46                                               int linear);
47 
48 template <typename TupleMatcher>
49 class TensorEqMatcher {
50  public:
TensorEqMatcher(const TupleMatcher & tuple_matcher,const TfLiteTensor & rhs)51   TensorEqMatcher(const TupleMatcher& tuple_matcher, const TfLiteTensor& rhs)
52       : tuple_matcher_(tuple_matcher), rhs_(rhs) {}
53 
54   // Make TensorEqMatcher movable only (The copy operations are implicitly
55   // deleted).
56   TensorEqMatcher(TensorEqMatcher&& other) = default;
57   TensorEqMatcher& operator=(TensorEqMatcher&& other) = default;
58 
59   template <typename T>
60   operator testing::Matcher<T>() const {  // NOLINT
61     return testing::Matcher<T>(new Impl(tuple_matcher_, rhs_));
62   }
63 
64   class Impl : public testing::MatcherInterface<TfLiteTensor> {
65    public:
66     typedef ::std::tuple<float, float> InnerMatcherArg;
67 
Impl(const TupleMatcher & tuple_matcher,const TfLiteTensor & rhs)68     Impl(const TupleMatcher& tuple_matcher, const TfLiteTensor& rhs)
69         : mono_tuple_matcher_(
70               testing::SafeMatcherCast<InnerMatcherArg>(tuple_matcher)),
71           rhs_(rhs) {}
72 
73     // Make Impl movable only (The copy operations are implicitly deleted).
74     Impl(Impl&& other) = default;
75     Impl& operator=(Impl&& other) = default;
76 
77     // Define what gtest framework will print for the Expected field.
DescribeTo(std::ostream * os)78     void DescribeTo(std::ostream* os) const override {
79       std::string shape;
80       std::optional<std::string> result = ShapeToString(rhs_.dims);
81       if (result.has_value()) {
82         shape = std::move(result.value());
83       } else {
84         shape = "[error: unsupported number of dimensions]";
85       }
86       *os << "tensor which has the shape of " << shape
87           << ", where each value and its corresponding expected value ";
88       mono_tuple_matcher_.DescribeTo(os);
89     }
90 
MatchAndExplain(TfLiteTensor lhs,testing::MatchResultListener * listener)91     bool MatchAndExplain(
92         TfLiteTensor lhs,
93         testing::MatchResultListener* listener) const override {
94       // 1. Check that TfLiteTensor data type is supported.
95       // Support for other data types will be added on demand.
96       if (lhs.type != kTfLiteFloat32 || rhs_.type != kTfLiteFloat32) {
97         *listener << "which data type is not float32, which is not currently "
98                      "supported.";
99         return false;
100       }
101 
102       // 2. Check that dimensions' sizes match. Otherwise, we are not able to
103       // compare tensors.
104       if (lhs.dims->size != rhs_.dims->size) {
105         *listener << "which is different from the expected shape of size "
106                   << rhs_.dims->size;
107         return false;
108       }
109       // 3. Check that dimensions' values are equal as well. We are not able to
110       // compare tensors of different shapes, even if the total elements count
111       // matches.
112       bool dims_are_equal = true;
113       for (int i = 0; i < lhs.dims->size; i++) {
114         dims_are_equal &= lhs.dims->data[i] == rhs_.dims->data[i];
115       }
116       if (!dims_are_equal) {
117         std::string shape;
118         std::optional<std::string> result = ShapeToString(rhs_.dims);
119         if (result.has_value()) {
120           shape = std::move(result.value());
121         } else {
122           shape = "[error: unsupported number of dimensions]";
123         }
124         *listener << "which is different from the expected shape " << shape;
125         return false;
126       }
127 
128       // 4. Proceed to data comparison. Iterate through elements as they lay
129       // flat. If some pair of elements don't match, deduct the coordinate
130       // basing on the dimensions, then return.
131       absl::Span<float> lhs_span(lhs.data.f, lhs.bytes / sizeof(float));
132       absl::Span<float> rhs_span(rhs_.data.f, rhs_.bytes / sizeof(float));
133 
134       auto left = lhs_span.begin();
135       auto right = rhs_span.begin();
136       for (size_t i = 0; i != lhs_span.size(); ++i, ++left, ++right) {
137         if (listener->IsInterested()) {
138           testing::StringMatchResultListener inner_listener;
139           if (!mono_tuple_matcher_.MatchAndExplain({*left, *right},
140                                                    &inner_listener)) {
141             *listener << "where the value pair (";
142             testing::internal::UniversalPrint(*left, listener->stream());
143             *listener << ", ";
144             testing::internal::UniversalPrint(*right, listener->stream());
145             std::string coordinate;
146             std::optional<std::string> result = CoordinateToString(lhs.dims, i);
147             if (result.has_value()) {
148               coordinate = std::move(result.value());
149             } else {
150               coordinate = "[error: unsupported number of dimensions]";
151             }
152             *listener << ") with coordinate " << coordinate << " don't match";
153             testing::internal::PrintIfNotEmpty(inner_listener.str(),
154                                                listener->stream());
155             return false;
156           }
157         } else {
158           if (!mono_tuple_matcher_.Matches({*left, *right})) return false;
159         }
160       }
161 
162       return true;
163     }
164 
165    private:
166     const testing::Matcher<InnerMatcherArg> mono_tuple_matcher_;
167     const TfLiteTensor rhs_;
168   };
169 
170  private:
171   const TupleMatcher tuple_matcher_;
172   const TfLiteTensor rhs_;
173 };
174 
175 // Builds interpreter for a model, allocates tensors.
176 absl::Status BuildInterpreter(const Model* model,
177                               std::unique_ptr<Interpreter>* interpreter);
178 
179 // Allocates tensors for a given interpreter.
180 absl::Status AllocateTensors(std::unique_ptr<Interpreter>* interpreter);
181 
182 // Modifies graph with given delegate.
183 absl::Status ModifyGraphWithDelegate(std::unique_ptr<Interpreter>* interpreter,
184                                      TfLiteDelegate* delegate);
185 
186 // Initializes inputs with consecutive values of some fixed range.
187 void InitializeInputs(int left, int right,
188                       std::unique_ptr<Interpreter>* interpreter);
189 
190 // Invokes a prebuilt interpreter.
191 absl::Status Invoke(std::unique_ptr<Interpreter>* interpreter);
192 
193 // Convenience structure used to pass parameter data to parameterized
194 // tests.
195 struct TestParams {
196   // The gtest name that will be used for the generated test.
197   std::string name;
198 
199   // The TFLite model associated with this test name, held as a byte buffer.
200   std::vector<uint8_t> model;
201 };
202 
203 // Defines how the TestParams should be printed into the command line if
204 // something fails during testing.
205 std::ostream& operator<<(std::ostream& os, const TestParams& param);
206 
207 }  // namespace tflite
208 
209 // Gtest framework uses this function to describe TfLiteTensor if something
210 // fails. TfLiteTensor is defined in global namespace, same should be done for
211 // streaming operator.
212 std::ostream& operator<<(std::ostream& os, const TfLiteTensor& tensor);
213 
214 // Defines a matcher to compare two TfLiteTensors pointwise using the given
215 // tuple matcher for comparing their values.
216 template <typename TupleMatcherT>
TensorEq(const TupleMatcherT & matcher,const TfLiteTensor & rhs)217 inline tflite::TensorEqMatcher<TupleMatcherT> TensorEq(
218     const TupleMatcherT& matcher, const TfLiteTensor& rhs) {
219   return tflite::TensorEqMatcher<TupleMatcherT>(matcher, rhs);
220 }
221 
222 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
223