xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeCast.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <string>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Cast")
15*89c4ff92SAndroid Build Coastguard Worker {
// Test fixture describing a minimal serialized network as flatbuffers JSON:
//
//     InputLayer (index 0, "inputTensor") -> CastLayer (index 1) -> OutputLayer (index 2, "outputTensor")
//
// The tensor shapes and data types are supplied by the derived fixture, so the
// same graph can be reused for different Cast conversions.
struct CastFixture : public ParserFlatbuffersSerializeFixture
{
    /// @brief Builds m_JsonString for an Input -> Cast -> Output graph and calls Setup().
    /// @param inputShape     Text spliced verbatim (unquoted) into the input tensor's
    ///                       `dimensions` field, e.g. "[ 1, 6 ]".
    /// @param outputShape    Text spliced verbatim into the `dimensions` of both the Cast
    ///                       layer's output slot and the Output layer's tensorInfo.
    /// @param inputDataType  Text spliced verbatim into the input tensor's `dataType`
    ///                       field, e.g. "Signed32".
    /// @param outputDataType Text spliced verbatim into the `dataType` of the Cast output
    ///                       and the Output layer's tensorInfo, e.g. "Float32".
    explicit CastFixture(const std::string& inputShape,
                         const std::string& outputShape,
                         const std::string& inputDataType,
                         const std::string& outputDataType)
    {
        // inputIds/outputIds are 0 and 2, matching both the `index` and the
        // `layerBindingId` of the Input and Output layers below.
        // Tests bind data by the layer names "inputTensor" and "outputTensor".
        // NOTE(review): the InputLayer declares an input slot connected to itself
        // (sourceLayerIndex 0) and the OutputLayer declares an output slot; these look
        // like copy-paste from other deserializer fixtures and are presumably ignored by
        // the deserializer -- confirm before relying on them.
        m_JsonString = R"(
            {
                inputIds: [0],
                outputIds: [2],
                layers: [
                    {
                        layer_type: "InputLayer",
                        layer: {
                            base: {
                                layerBindingId: 0,
                                base: {
                                    index: 0,
                                    layerName: "inputTensor",
                                    layerType: "Input",
                                    inputSlots: [{
                                        index: 0,
                                        connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + inputShape + R"(,
                                            dataType: )" + inputDataType + R"(
                                        }
                                    }]
                                }
                            }
                        }
                    },
                    {
                        layer_type: "CastLayer",
                        layer: {
                            base: {
                                 index:1,
                                 layerName: "CastLayer",
                                 layerType: "Cast",
                                 inputSlots: [{
                                     index: 0,
                                     connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                 }],
                                 outputSlots: [{
                                     index: 0,
                                     tensorInfo: {
                                         dimensions: )" + outputShape + R"(,
                                         dataType: )" + outputDataType + R"(
                                     },
                                 }],
                            },
                        },
                    },
                    {
                        layer_type: "OutputLayer",
                        layer: {
                            base:{
                                layerBindingId: 2,
                                base: {
                                    index: 2,
                                    layerName: "outputTensor",
                                    layerType: "Output",
                                    inputSlots: [{
                                        index: 0,
                                        connection: { sourceLayerIndex:1, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + outputShape + R"(,
                                            dataType: )" + outputDataType + R"(
                                        },
                                    }],
                                }
                            }
                        },
                    }
                ]
            }
        )";
        // Setup() is inherited from ParserFlatbuffersSerializeFixture; presumably it
        // turns m_JsonString into a flatbuffer and deserializes the network -- see
        // that fixture for the exact behavior.
        Setup();
    }
};
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker struct SimpleCastFixture : CastFixture
105*89c4ff92SAndroid Build Coastguard Worker {
SimpleCastFixtureSimpleCastFixture106*89c4ff92SAndroid Build Coastguard Worker     SimpleCastFixture() : CastFixture("[ 1, 6 ]",
107*89c4ff92SAndroid Build Coastguard Worker                                       "[ 1, 6 ]",
108*89c4ff92SAndroid Build Coastguard Worker                                       "Signed32",
109*89c4ff92SAndroid Build Coastguard Worker                                       "Float32") {}
110*89c4ff92SAndroid Build Coastguard Worker };
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleCastFixture, "SimpleCast")
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Signed32 , armnn::DataType::Float32>(
115*89c4ff92SAndroid Build Coastguard Worker         0,
116*89c4ff92SAndroid Build Coastguard Worker         {{"inputTensor",  { 0,   -1,   5,   -100,   200,   -255 }}},
117*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor", { 0.0f, -1.0f, 5.0f, -100.0f, 200.0f, -255.0f }}});
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker }
121