1 // Copyright 2022 Google LLC 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <gtest/gtest.h> 9 10 #include <xnnpack.h> 11 #include <xnnpack/normalization.h> 12 13 14 class TransposeNormalizationTester { 15 public: num_dims(size_t num_dims)16 inline TransposeNormalizationTester& num_dims(size_t num_dims) { 17 assert(num_dims != 0); 18 this->num_dims_ = num_dims; 19 return *this; 20 } 21 num_dims()22 inline size_t num_dims() const { return this->num_dims_; } 23 element_size(size_t element_size)24 inline TransposeNormalizationTester& element_size(size_t element_size) { 25 this->element_size_ = element_size; 26 return *this; 27 } 28 element_size()29 inline size_t element_size() const { return this->element_size_; } 30 expected_dims(size_t expected_dims)31 inline TransposeNormalizationTester& expected_dims(size_t expected_dims) { 32 this->expected_dims_ = expected_dims; 33 return *this; 34 } 35 expected_dims()36 inline size_t expected_dims() const { return this->expected_dims_; } 37 expected_element_size(size_t expected_element_size)38 inline TransposeNormalizationTester& expected_element_size(size_t expected_element_size) { 39 this->expected_element_size_ = expected_element_size; 40 return *this; 41 } 42 expected_element_size()43 inline size_t expected_element_size() const { return this->expected_element_size_; } 44 shape(const std::vector<size_t> shape)45 inline TransposeNormalizationTester& shape(const std::vector<size_t> shape) { 46 assert(shape.size() <= XNN_MAX_TENSOR_DIMS); 47 this->shape_ = shape; 48 return *this; 49 } 50 perm(const std::vector<size_t> perm)51 inline TransposeNormalizationTester& perm(const std::vector<size_t> perm) { 52 assert(perm.size() <= XNN_MAX_TENSOR_DIMS); 53 this->perm_ = perm; 54 return *this; 55 } 56 input_stride(const std::vector<size_t> input_stride)57 inline TransposeNormalizationTester& input_stride(const std::vector<size_t> input_stride) { 58 assert(input_stride.size() <= XNN_MAX_TENSOR_DIMS); 59 this->input_stride_ = input_stride; 60 return *this; 61 } 62 output_stride(const std::vector<size_t> output_stride)63 inline TransposeNormalizationTester& output_stride(const std::vector<size_t> output_stride) { 64 assert(output_stride.size() <= XNN_MAX_TENSOR_DIMS); 65 this->output_stride_ = output_stride; 66 return *this; 67 } 68 expected_shape(const std::vector<size_t> expected_shape)69 inline TransposeNormalizationTester& expected_shape(const std::vector<size_t> expected_shape) { 70 this->expected_shape_ = expected_shape; 71 return *this; 72 } 73 expected_shape()74 inline const std::vector<size_t>& expected_shape() const { return this->expected_shape_; } 75 expected_perm(const std::vector<size_t> expected_perm)76 inline TransposeNormalizationTester& expected_perm(const std::vector<size_t> expected_perm) { 77 this->expected_perm_ = expected_perm; 78 return *this; 79 } 80 expected_perm()81 inline const std::vector<size_t>& expected_perm() const { return this->expected_perm_; } 82 expected_input_stride(const std::vector<size_t> expected_input_stride)83 inline TransposeNormalizationTester& expected_input_stride(const std::vector<size_t> expected_input_stride) { 84 this->expected_input_stride_ = expected_input_stride; 85 return *this; 86 } 87 expected_output_stride(const std::vector<size_t> expected_output_stride)88 inline TransposeNormalizationTester& expected_output_stride(const std::vector<size_t> expected_output_stride) { 89 this->expected_output_stride_ = expected_output_stride; 90 return *this; 91 } 92 expected_input_stride()93 inline const std::vector<size_t>& expected_input_stride() const { return this->expected_input_stride_; } 94 expected_output_stride()95 inline const std::vector<size_t>& expected_output_stride() const { return this->expected_output_stride_; } 96 calculate_expected_input_stride()97 inline TransposeNormalizationTester& calculate_expected_input_stride() { 98 expected_input_stride_.resize(expected_dims()); 99 expected_input_stride_[expected_dims() - 1] = expected_element_size(); 100 for(size_t i = expected_dims() - 1; i-- != 0;) { 101 expected_input_stride_[i] = expected_input_stride_[i + 1] * expected_shape_[i + 1]; 102 } 103 return *this; 104 } 105 calculate_expected_output_stride()106 inline TransposeNormalizationTester& calculate_expected_output_stride() { 107 expected_output_stride_.resize(expected_dims()); 108 expected_output_stride_[expected_dims() - 1] = expected_element_size(); 109 for(size_t i = expected_dims() - 1; i-- != 0;) { 110 expected_output_stride_[i] = expected_output_stride_[i + 1] 111 * expected_shape_[expected_perm_[i + 1]]; 112 } 113 return *this; 114 } 115 Test()116 void Test() const { 117 size_t actual_element_size; 118 size_t actual_normalized_dims; 119 std::vector<size_t> actual_normalized_shape(num_dims()); 120 std::vector<size_t> actual_normalized_perm(num_dims()); 121 std::vector<size_t> actual_normalized_input_stride(num_dims()); 122 std::vector<size_t> actual_normalized_output_stride(num_dims()); 123 124 xnn_normalize_transpose_permutation(num_dims(), element_size(), perm_.data(), 125 shape_.data(), input_stride_.empty() ? nullptr : input_stride_.data(), 126 output_stride_.empty() ? nullptr : output_stride_.data(), 127 &actual_normalized_dims, &actual_element_size, actual_normalized_perm.data(), 128 actual_normalized_shape.data(), actual_normalized_input_stride.data(), 129 actual_normalized_output_stride.data()); 130 EXPECT_EQ(expected_element_size(), actual_element_size); 131 EXPECT_EQ(expected_dims(), actual_normalized_dims); 132 133 for (size_t i = 0; i < expected_dims(); ++i) { 134 EXPECT_EQ(expected_shape()[i], actual_normalized_shape[i]); 135 EXPECT_EQ(expected_perm()[i], actual_normalized_perm[i]); 136 EXPECT_EQ(expected_input_stride()[i], actual_normalized_input_stride[i]); 137 EXPECT_EQ(expected_output_stride()[i], actual_normalized_output_stride[i]); 138 } 139 } 140 141 private: 142 size_t num_dims_; 143 size_t element_size_; 144 size_t expected_dims_; 145 size_t expected_element_size_; 146 std::vector<size_t> shape_; 147 std::vector<size_t> perm_; 148 std::vector<size_t> input_stride_; 149 std::vector<size_t> output_stride_; 150 std::vector<size_t> expected_shape_; 151 std::vector<size_t> expected_perm_; 152 std::vector<size_t> expected_input_stride_; 153 std::vector<size_t> expected_output_stride_; 154 }; 155