// xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/xnnpack/concatenation_tester.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_CONCATENATION_TESTER_H_
17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_CONCATENATION_TESTER_H_
18 
19 #include <cstdint>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/interpreter.h"
25 #include "tensorflow/lite/schema/schema_generated.h"
26 
27 namespace tflite {
28 namespace xnnpack {
29 
// Creates a new shape with the same dimensions as `shape`, except for the
// dimension at index `axis`, which will have the value `size`.
//
// `shape` is taken by value, so the caller's vector is never modified; the
// adjusted shape is returned as a new vector.
//
// NOTE(review): the definition is not part of this header. Whether a negative
// `axis` (counting from the last dimension) is accepted here should be
// confirmed against the implementation.
std::vector<int32_t> SameShapeDifferentAxis(std::vector<int32_t> shape,
                                            int axis, int32_t size);
34 
35 class ConcatenationTester {
36  public:
37   ConcatenationTester() = default;
38   ConcatenationTester(const ConcatenationTester&) = delete;
39   ConcatenationTester& operator=(const ConcatenationTester&) = delete;
40 
Axis(int axis)41   inline ConcatenationTester& Axis(int axis) {
42     axis_ = axis;
43     return *this;
44   }
45 
Axis()46   inline const int Axis() const { return axis_; }
47 
InputShapes(const std::initializer_list<std::vector<int32_t>> shapes)48   inline ConcatenationTester& InputShapes(
49       const std::initializer_list<std::vector<int32_t>> shapes) {
50     for (auto shape : shapes) {
51       for (auto it = shape.begin(); it != shape.end(); ++it) {
52         EXPECT_GT(*it, 0);
53       }
54     }
55     input_shapes_ = shapes;
56     return *this;
57   }
58 
InputShape(size_t i)59   inline std::vector<int32_t> InputShape(size_t i) const {
60     return input_shapes_[i];
61   }
62 
NumInputs()63   inline size_t NumInputs() const { return input_shapes_.size(); }
64 
OutputShape()65   std::vector<int32_t> OutputShape() const {
66     std::vector<int32_t> output_shape = InputShape(0);
67     int concat_axis = Axis() < 0 ? Axis() + output_shape.size() : Axis();
68     size_t axis_dim_size = 0;
69     for (size_t i = 0; i < NumInputs(); i++) {
70       axis_dim_size += InputShape(i)[concat_axis];
71     }
72     output_shape[concat_axis] = axis_dim_size;
73     return output_shape;
74   }
75 
76   template <typename T>
77   void Test(Interpreter* delegate_interpreter,
78             Interpreter* default_interpreter) const;
79   void Test(TensorType tensor_type, TfLiteDelegate* delegate) const;
80 
81  private:
82   std::vector<char> CreateTfLiteModel(TensorType tensor_type) const;
83 
84   static int32_t ComputeSize(const std::vector<int32_t>& shape);
85 
86   int axis_;
87   std::vector<int32_t> output_shape_;
88   std::vector<std::vector<int32_t>> input_shapes_;
89 };
90 
91 }  // namespace xnnpack
92 }  // namespace tflite
93 
94 #endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_CONCATENATION_TESTER_H_
95