// xref: /aosp_15_r20/external/XNNPACK/test/transpose-normalization-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include <xnnpack.h>
#include <xnnpack/normalization.h>

14 class TransposeNormalizationTester {
15  public:
num_dims(size_t num_dims)16   inline TransposeNormalizationTester& num_dims(size_t num_dims) {
17     assert(num_dims != 0);
18     this->num_dims_ = num_dims;
19     return *this;
20   }
21 
num_dims()22   inline size_t num_dims() const { return this->num_dims_; }
23 
element_size(size_t element_size)24   inline TransposeNormalizationTester& element_size(size_t element_size) {
25     this->element_size_ = element_size;
26     return *this;
27   }
28 
element_size()29   inline size_t element_size() const { return this->element_size_; }
30 
expected_dims(size_t expected_dims)31   inline TransposeNormalizationTester& expected_dims(size_t expected_dims) {
32     this->expected_dims_ = expected_dims;
33     return *this;
34   }
35 
expected_dims()36   inline size_t expected_dims() const { return this->expected_dims_; }
37 
expected_element_size(size_t expected_element_size)38   inline TransposeNormalizationTester& expected_element_size(size_t expected_element_size) {
39     this->expected_element_size_ = expected_element_size;
40     return *this;
41   }
42 
expected_element_size()43   inline size_t expected_element_size() const { return this->expected_element_size_; }
44 
shape(const std::vector<size_t> shape)45   inline TransposeNormalizationTester& shape(const std::vector<size_t> shape) {
46     assert(shape.size() <= XNN_MAX_TENSOR_DIMS);
47     this->shape_ = shape;
48     return *this;
49   }
50 
perm(const std::vector<size_t> perm)51   inline TransposeNormalizationTester& perm(const std::vector<size_t> perm) {
52     assert(perm.size() <= XNN_MAX_TENSOR_DIMS);
53     this->perm_ = perm;
54     return *this;
55   }
56 
input_stride(const std::vector<size_t> input_stride)57   inline TransposeNormalizationTester& input_stride(const std::vector<size_t> input_stride) {
58     assert(input_stride.size() <= XNN_MAX_TENSOR_DIMS);
59     this->input_stride_ = input_stride;
60     return *this;
61   }
62 
output_stride(const std::vector<size_t> output_stride)63   inline TransposeNormalizationTester& output_stride(const std::vector<size_t> output_stride) {
64     assert(output_stride.size() <= XNN_MAX_TENSOR_DIMS);
65     this->output_stride_ = output_stride;
66     return *this;
67   }
68 
expected_shape(const std::vector<size_t> expected_shape)69   inline TransposeNormalizationTester& expected_shape(const std::vector<size_t> expected_shape) {
70     this->expected_shape_ = expected_shape;
71     return *this;
72   }
73 
expected_shape()74   inline const std::vector<size_t>& expected_shape() const { return this->expected_shape_; }
75 
expected_perm(const std::vector<size_t> expected_perm)76   inline TransposeNormalizationTester& expected_perm(const std::vector<size_t> expected_perm) {
77     this->expected_perm_ = expected_perm;
78     return *this;
79   }
80 
expected_perm()81   inline const std::vector<size_t>& expected_perm() const { return this->expected_perm_; }
82 
expected_input_stride(const std::vector<size_t> expected_input_stride)83   inline TransposeNormalizationTester& expected_input_stride(const std::vector<size_t> expected_input_stride) {
84     this->expected_input_stride_ = expected_input_stride;
85     return *this;
86   }
87 
expected_output_stride(const std::vector<size_t> expected_output_stride)88   inline TransposeNormalizationTester& expected_output_stride(const std::vector<size_t> expected_output_stride) {
89     this->expected_output_stride_ = expected_output_stride;
90     return *this;
91   }
92 
expected_input_stride()93   inline const std::vector<size_t>& expected_input_stride() const { return this->expected_input_stride_; }
94 
expected_output_stride()95   inline const std::vector<size_t>& expected_output_stride() const { return this->expected_output_stride_; }
96 
calculate_expected_input_stride()97   inline TransposeNormalizationTester& calculate_expected_input_stride() {
98     expected_input_stride_.resize(expected_dims());
99     expected_input_stride_[expected_dims() - 1] = expected_element_size();
100     for(size_t i = expected_dims() - 1; i-- != 0;) {
101       expected_input_stride_[i] = expected_input_stride_[i + 1] * expected_shape_[i + 1];
102     }
103     return *this;
104   }
105 
calculate_expected_output_stride()106   inline TransposeNormalizationTester& calculate_expected_output_stride() {
107     expected_output_stride_.resize(expected_dims());
108     expected_output_stride_[expected_dims() - 1] = expected_element_size();
109     for(size_t i = expected_dims() - 1; i-- != 0;) {
110       expected_output_stride_[i] = expected_output_stride_[i + 1]
111           * expected_shape_[expected_perm_[i + 1]];
112     }
113     return *this;
114   }
115 
Test()116   void Test() const {
117     size_t actual_element_size;
118     size_t actual_normalized_dims;
119     std::vector<size_t> actual_normalized_shape(num_dims());
120     std::vector<size_t> actual_normalized_perm(num_dims());
121     std::vector<size_t> actual_normalized_input_stride(num_dims());
122     std::vector<size_t> actual_normalized_output_stride(num_dims());
123 
124     xnn_normalize_transpose_permutation(num_dims(), element_size(), perm_.data(),
125                                         shape_.data(), input_stride_.empty() ? nullptr : input_stride_.data(),
126                                         output_stride_.empty() ? nullptr : output_stride_.data(),
127                                         &actual_normalized_dims, &actual_element_size, actual_normalized_perm.data(),
128                                         actual_normalized_shape.data(), actual_normalized_input_stride.data(),
129                                         actual_normalized_output_stride.data());
130     EXPECT_EQ(expected_element_size(), actual_element_size);
131     EXPECT_EQ(expected_dims(), actual_normalized_dims);
132 
133     for (size_t i = 0; i < expected_dims(); ++i) {
134       EXPECT_EQ(expected_shape()[i], actual_normalized_shape[i]);
135       EXPECT_EQ(expected_perm()[i], actual_normalized_perm[i]);
136       EXPECT_EQ(expected_input_stride()[i], actual_normalized_input_stride[i]);
137       EXPECT_EQ(expected_output_stride()[i], actual_normalized_output_stride[i]);
138     }
139   }
140 
141  private:
142   size_t num_dims_;
143   size_t element_size_;
144   size_t expected_dims_;
145   size_t expected_element_size_;
146   std::vector<size_t> shape_;
147   std::vector<size_t> perm_;
148   std::vector<size_t> input_stride_;
149   std::vector<size_t> output_stride_;
150   std::vector<size_t> expected_shape_;
151   std::vector<size_t> expected_perm_;
152   std::vector<size_t> expected_input_stride_;
153   std::vector<size_t> expected_output_stride_;
154 };
155