xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeTranspose.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Transpose")
12*89c4ff92SAndroid Build Coastguard Worker {
// Base fixture that builds the JSON text of a minimal serialized network
//
//   InputLayer (index 0, binding id 0)
//     -> TransposeLayer (index 1)
//     -> OutputLayer (index 2, binding id 2)
//
// and stores it in m_JsonString before handing it to the parser fixture.
//
// All four constructor arguments are spliced verbatim into the JSON text, so each
// must already be written exactly as it should appear there, e.g. "[ 2, 3 ]" for a
// shape/mapping or "QuantisedAsymm8" for a data type:
//   inputShape  - dimensions of the InputLayer's output tensor.
//   dimMappings - value of the TransposeLayer descriptor's dimMappings field.
//   outputShape - dimensions of both the TransposeLayer's and the OutputLayer's
//                 output tensors.
//   dataType    - data type used for all three tensors.
//
// NOTE(review): the InputLayer's inputSlots entry connects back to layer 0 itself;
// this looks like placeholder data for the serialized schema. Confirm that it is
// ignored for Input layers.
struct TransposeFixture : public ParserFlatbuffersSerializeFixture
{
    explicit TransposeFixture(const std::string &inputShape,
                              const std::string &dimMappings,
                              const std::string &outputShape,
                              const std::string &dataType)
    {
        m_JsonString = R"(
            {
                inputIds: [0],
                outputIds: [2],
                layers: [
                    {
                        layer_type: "InputLayer",
                        layer: {
                            base: {
                                layerBindingId: 0,
                                base: {
                                    index: 0,
                                    layerName: "InputLayer",
                                    layerType: "Input",
                                    inputSlots: [{
                                        index: 0,
                                        connection: {sourceLayerIndex:0, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + inputShape + R"(,
                                            dataType: )" + dataType + R"(
                                        }
                                    }]
                                }
                            }
                        }
                    },
                    {
                        layer_type: "TransposeLayer",
                        layer: {
                            base: {
                                index: 1,
                                layerName: "TransposeLayer",
                                layerType: "Transpose",
                                inputSlots: [{
                                    index: 0,
                                    connection: {sourceLayerIndex:0, outputSlotIndex:0 },
                                }],
                                outputSlots: [{
                                    index: 0,
                                    tensorInfo: {
                                        dimensions: )" + outputShape + R"(,
                                        dataType: )" + dataType + R"(
                                    }
                                }]
                            },
                            descriptor: {
                                dimMappings: )" + dimMappings + R"(,
                            }
                        }
                    },
                    {
                        layer_type: "OutputLayer",
                        layer: {
                            base:{
                                layerBindingId: 2,
                                base: {
                                    index: 2,
                                    layerName: "OutputLayer",
                                    layerType: "Output",
                                    inputSlots: [{
                                        index: 0,
                                        connection: {sourceLayerIndex:1, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + outputShape + R"(,
                                            dataType: )" + dataType + R"(
                                        },
                                    }],
                                }
                            }
                        },
                    }
                ]
            }
        )";
        // SetupSingleInputSingleOutput is defined in ParserFlatbuffersSerializeFixture
        // (not shown here). It presumably parses m_JsonString and binds the layers
        // named here as the network's single input and single output; confirm there.
        SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
    }
};
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker struct SimpleTranspose2DFixture : TransposeFixture
105*89c4ff92SAndroid Build Coastguard Worker {
SimpleTranspose2DFixtureSimpleTranspose2DFixture106*89c4ff92SAndroid Build Coastguard Worker     SimpleTranspose2DFixture() : TransposeFixture("[ 2, 3 ]",
107*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 1, 0 ]",
108*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 3, 2 ]",
109*89c4ff92SAndroid Build Coastguard Worker                                                   "QuantisedAsymm8") {}
110*89c4ff92SAndroid Build Coastguard Worker };
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleTranspose2DFixture, "SimpleTranspose2DQuantisedAsymm8")
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(0,
115*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 2, 3, 4, 5, 6 },
116*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 4, 2, 5, 3, 6 });
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker struct SimpleTranspose4DFixture : TransposeFixture
120*89c4ff92SAndroid Build Coastguard Worker {
SimpleTranspose4DFixtureSimpleTranspose4DFixture121*89c4ff92SAndroid Build Coastguard Worker     SimpleTranspose4DFixture() : TransposeFixture("[ 1, 2, 3, 4 ]",
122*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 3, 2, 1, 0 ]",
123*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 4, 3, 2, 1 ]",
124*89c4ff92SAndroid Build Coastguard Worker                                                   "QuantisedAsymm8") {}
125*89c4ff92SAndroid Build Coastguard Worker };
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleTranspose4DFixture, "SimpleTranspose4DQuantisedAsymm8")
128*89c4ff92SAndroid Build Coastguard Worker {
129*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(0,
130*89c4ff92SAndroid Build Coastguard Worker                                                  {  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,
131*89c4ff92SAndroid Build Coastguard Worker                                                    13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 },
132*89c4ff92SAndroid Build Coastguard Worker                                                  {  1, 13,  5, 17,  9, 21,  2, 14,  6, 18, 10, 22,
133*89c4ff92SAndroid Build Coastguard Worker                                                     3, 15,  7, 19, 11, 23,  4, 16,  8, 20, 12, 24 });
134*89c4ff92SAndroid Build Coastguard Worker }
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker }
137