// xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeRsqrt.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 #include <string>
10 
11 TEST_SUITE("Deserializer_Rsqrt")
12 {
13 struct RsqrtFixture : public ParserFlatbuffersSerializeFixture
14 {
RsqrtFixtureRsqrtFixture15     explicit RsqrtFixture(const std::string & inputShape,
16                           const std::string & outputShape,
17                           const std::string & dataType)
18     {
19         m_JsonString = R"(
20         {
21                 inputIds: [0],
22                 outputIds: [2],
23                 layers: [
24                 {
25                     layer_type: "InputLayer",
26                     layer: {
27                           base: {
28                                 layerBindingId: 0,
29                                 base: {
30                                     index: 0,
31                                     layerName: "InputLayer",
32                                     layerType: "Input",
33                                     inputSlots: [{
34                                         index: 0,
35                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
36                                     }],
37                                     outputSlots: [ {
38                                         index: 0,
39                                         tensorInfo: {
40                                             dimensions: )" + inputShape + R"(,
41                                             dataType: )" + dataType + R"(
42                                         },
43                                     }],
44                                  },}},
45                 },
46                 {
47                 layer_type: "RsqrtLayer",
48                 layer : {
49                         base: {
50                              index:1,
51                              layerName: "RsqrtLayer",
52                              layerType: "Rsqrt",
53                              inputSlots: [
54                                             {
55                                              index: 0,
56                                              connection: {sourceLayerIndex:0, outputSlotIndex:0 },
57                                             }
58                              ],
59                              outputSlots: [ {
60                                  index: 0,
61                                  tensorInfo: {
62                                      dimensions: )" + outputShape + R"(,
63                                      dataType: )" + dataType + R"(
64                                  },
65                              }],
66                             }},
67                 },
68                 {
69                 layer_type: "OutputLayer",
70                 layer: {
71                         base:{
72                               layerBindingId: 0,
73                               base: {
74                                     index: 2,
75                                     layerName: "OutputLayer",
76                                     layerType: "Output",
77                                     inputSlots: [{
78                                         index: 0,
79                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
80                                     }],
81                                     outputSlots: [ {
82                                         index: 0,
83                                         tensorInfo: {
84                                             dimensions: )" + outputShape + R"(,
85                                             dataType: )" + dataType + R"(
86                                         },
87                                 }],
88                             }}},
89                 }]
90          }
91         )";
92         Setup();
93     }
94 };
95 
96 
97 struct Rsqrt2dFixture : RsqrtFixture
98 {
Rsqrt2dFixtureRsqrt2dFixture99     Rsqrt2dFixture() : RsqrtFixture("[ 2, 2 ]",
100                                     "[ 2, 2 ]",
101                                     "Float32") {}
102 };
103 
104 TEST_CASE_FIXTURE(Rsqrt2dFixture, "Rsqrt2d")
105 {
106   RunTest<2, armnn::DataType::Float32>(
107       0,
108       {{"InputLayer", { 1.0f,  4.0f,
109                         16.0f, 25.0f }}},
110       {{"OutputLayer",{ 1.0f,  0.5f,
111                         0.25f, 0.2f }}});
112 }
113 
114 
115 }
116