xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Transpose.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 TEST_SUITE("TensorflowLiteParser_Transpose")
9 {
10 struct TransposeFixture : public ParserFlatbuffersFixture
11 {
TransposeFixtureTransposeFixture12     explicit TransposeFixture(const std::string & inputShape,
13                               const std::string & permuteData,
14                               const std::string & outputShape)
15     {
16         m_JsonString = R"(
17             {
18                   "version": 3,
19                   "operator_codes": [
20                     {
21                       "builtin_code": "TRANSPOSE",
22                       "version": 1
23                     }
24                   ],
25                   "subgraphs": [
26                     {
27                       "tensors": [
28                         {
29                           "shape": )" + inputShape + R"(,
30                           "type": "FLOAT32",
31                           "buffer": 0,
32                           "name": "inputTensor",
33                           "quantization": {
34                             "min": [
35                               0.0
36                             ],
37                             "max": [
38                               255.0
39                             ],
40                             "details_type": 0,
41                             "quantized_dimension": 0
42                           },
43                           "is_variable": false
44                         },
45                         {
46                           "shape": )" + outputShape + R"(,
47                           "type": "FLOAT32",
48                           "buffer": 1,
49                           "name": "outputTensor",
50                           "quantization": {
51                             "details_type": 0,
52                             "quantized_dimension": 0
53                           },
54                           "is_variable": false
55                         })";
56         m_JsonString += R"(,
57                           {
58                             "shape": [
59                               3
60                             ],
61                             "type": "INT32",
62                             "buffer": 2,
63                             "name": "permuteTensor",
64                             "quantization": {
65                               "details_type": 0,
66                               "quantized_dimension": 0
67                             },
68                             "is_variable": false
69                           })";
70         m_JsonString += R"(],
71                       "inputs": [
72                         0
73                       ],
74                       "outputs": [
75                         1
76                       ],
77                       "operators": [
78                         {
79                           "opcode_index": 0,
80                           "inputs": [
81                             0)";
82         m_JsonString += R"(,2)";
83         m_JsonString += R"(],
84                           "outputs": [
85                             1
86                           ],
87                           "builtin_options_type": "TransposeOptions",
88                           "builtin_options": {
89                           },
90                           "custom_options_format": "FLEXBUFFERS"
91                         }
92                       ]
93                     }
94                   ],
95                   "description": "TOCO Converted.",
96                   "buffers": [
97                     { },
98                     { })";
99         if (!permuteData.empty())
100         {
101             m_JsonString += R"(,{"data": )" + permuteData + R"( })";
102         }
103         m_JsonString += R"(
104                   ]
105                 }
106         )";
107         Setup();
108     }
109 };
110 
// Note that this assumes the Tensorflow permutation vector implementation as opposed to the armnn implementation.
112 struct TransposeFixtureWithPermuteData : TransposeFixture
113 {
TransposeFixtureWithPermuteDataTransposeFixtureWithPermuteData114     TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
115                                                          "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]",
116                                                          "[ 2, 3, 2 ]") {}
117 };
118 
119 TEST_CASE_FIXTURE(TransposeFixtureWithPermuteData, "TransposeWithPermuteData")
120 {
121     RunTest<3, armnn::DataType::Float32>(
122       0,
123       {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
124       {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
125 
126     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
127                 == armnn::TensorShape({2,3,2})));
128 }
129 
// Tensorflow's default permutation behavior, when no permute argument is given, is to use the
// permute vector [n-1...0], where n is the number of dimensions of the input tensor.
// In this case we should get output shape 3,2,2 given the default permutation vector 2,1,0.
133 struct TransposeFixtureWithoutPermuteData : TransposeFixture
134 {
TransposeFixtureWithoutPermuteDataTransposeFixtureWithoutPermuteData135     TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
136                                                             "[ 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]",
137                                                             "[ 3, 2, 2 ]") {}
138 };
139 
140 TEST_CASE_FIXTURE(TransposeFixtureWithoutPermuteData, "TransposeWithoutPermuteDims")
141 {
142     RunTest<3, armnn::DataType::Float32>(
143         0,
144         {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
145         {{"outputTensor", { 1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12 }}});
146 
147     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
148                 == armnn::TensorShape({3,2,2})));
149 }
150 
151 }