1 // 2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 TEST_SUITE("TensorflowLiteParser_Transpose") 9 { 10 struct TransposeFixture : public ParserFlatbuffersFixture 11 { TransposeFixtureTransposeFixture12 explicit TransposeFixture(const std::string & inputShape, 13 const std::string & permuteData, 14 const std::string & outputShape) 15 { 16 m_JsonString = R"( 17 { 18 "version": 3, 19 "operator_codes": [ 20 { 21 "builtin_code": "TRANSPOSE", 22 "version": 1 23 } 24 ], 25 "subgraphs": [ 26 { 27 "tensors": [ 28 { 29 "shape": )" + inputShape + R"(, 30 "type": "FLOAT32", 31 "buffer": 0, 32 "name": "inputTensor", 33 "quantization": { 34 "min": [ 35 0.0 36 ], 37 "max": [ 38 255.0 39 ], 40 "details_type": 0, 41 "quantized_dimension": 0 42 }, 43 "is_variable": false 44 }, 45 { 46 "shape": )" + outputShape + R"(, 47 "type": "FLOAT32", 48 "buffer": 1, 49 "name": "outputTensor", 50 "quantization": { 51 "details_type": 0, 52 "quantized_dimension": 0 53 }, 54 "is_variable": false 55 })"; 56 m_JsonString += R"(, 57 { 58 "shape": [ 59 3 60 ], 61 "type": "INT32", 62 "buffer": 2, 63 "name": "permuteTensor", 64 "quantization": { 65 "details_type": 0, 66 "quantized_dimension": 0 67 }, 68 "is_variable": false 69 })"; 70 m_JsonString += R"(], 71 "inputs": [ 72 0 73 ], 74 "outputs": [ 75 1 76 ], 77 "operators": [ 78 { 79 "opcode_index": 0, 80 "inputs": [ 81 0)"; 82 m_JsonString += R"(,2)"; 83 m_JsonString += R"(], 84 "outputs": [ 85 1 86 ], 87 "builtin_options_type": "TransposeOptions", 88 "builtin_options": { 89 }, 90 "custom_options_format": "FLEXBUFFERS" 91 } 92 ] 93 } 94 ], 95 "description": "TOCO Converted.", 96 "buffers": [ 97 { }, 98 { })"; 99 if (!permuteData.empty()) 100 { 101 m_JsonString += R"(,{"data": )" + permuteData + R"( })"; 102 } 103 m_JsonString += R"( 104 ] 105 } 106 )"; 107 Setup(); 108 } 109 }; 110 111 // Note that this assumes the Tensorflow permutation vector implementation as opposed to the armnn implemenation. 112 struct TransposeFixtureWithPermuteData : TransposeFixture 113 { TransposeFixtureWithPermuteDataTransposeFixtureWithPermuteData114 TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]", 115 "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]", 116 "[ 2, 3, 2 ]") {} 117 }; 118 119 TEST_CASE_FIXTURE(TransposeFixtureWithPermuteData, "TransposeWithPermuteData") 120 { 121 RunTest<3, armnn::DataType::Float32>( 122 0, 123 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}}, 124 {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}}); 125 126 CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() 127 == armnn::TensorShape({2,3,2}))); 128 } 129 130 // Tensorflow default permutation behavior assumes no permute argument will create permute vector [n-1...0], 131 // where n is the number of dimensions of the input tensor 132 // In this case we should get output shape 3,2,2 given default permutation vector 2,1,0 133 struct TransposeFixtureWithoutPermuteData : TransposeFixture 134 { TransposeFixtureWithoutPermuteDataTransposeFixtureWithoutPermuteData135 TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]", 136 "[ 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]", 137 "[ 3, 2, 2 ]") {} 138 }; 139 140 TEST_CASE_FIXTURE(TransposeFixtureWithoutPermuteData, "TransposeWithoutPermuteDims") 141 { 142 RunTest<3, armnn::DataType::Float32>( 143 0, 144 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}}, 145 {{"outputTensor", { 1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12 }}}); 146 147 CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() 148 == armnn::TensorShape({3,2,2}))); 149 } 150 151 }