xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Div.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 
9 TEST_SUITE("TensorflowLiteParser_Div")
10 {
11 struct DivFixture : public ParserFlatbuffersFixture
12 {
DivFixtureDivFixture13     explicit DivFixture(const std::string & inputShape1,
14                         const std::string & inputShape2,
15                         const std::string & outputShape,
16                         const std::string & activation="NONE")
17     {
18         m_JsonString = R"(
19             {
20                 "version": 3,
21                 "operator_codes": [ { "builtin_code": "DIV" } ],
22                 "subgraphs": [ {
23                     "tensors": [
24                         {
25                             "shape": )" + inputShape1 + R"(,
26                             "type": "FLOAT32",
27                             "buffer": 0,
28                             "name": "inputTensor1",
29                             "quantization": {
30                                 "min": [ 0.0 ],
31                                 "max": [ 255.0 ],
32                                 "scale": [ 1.0 ],
33                                 "zero_point": [ 0 ],
34                             }
35                         },
36                         {
37                             "shape": )" + inputShape2 + R"(,
38                             "type": "FLOAT32",
39                             "buffer": 1,
40                             "name": "inputTensor2",
41                             "quantization": {
42                                 "min": [ 0.0 ],
43                                 "max": [ 255.0 ],
44                                 "scale": [ 1.0 ],
45                                 "zero_point": [ 0 ],
46                             }
47                         },
48                         {
49                             "shape": )" + outputShape + R"( ,
50                             "type": "FLOAT32",
51                             "buffer": 2,
52                             "name": "outputTensor",
53                             "quantization": {
54                                 "min": [ 0.0 ],
55                                 "max": [ 255.0 ],
56                                 "scale": [ 1.0 ],
57                                 "zero_point": [ 0 ],
58                             }
59                         }
60                     ],
61                     "inputs": [ 0, 1 ],
62                     "outputs": [ 2 ],
63                     "operators": [
64                         {
65                             "opcode_index": 0,
66                             "inputs": [ 0, 1 ],
67                             "outputs": [ 2 ],
68                             "builtin_options_type": "DivOptions",
69                             "builtin_options": {
70                                 "fused_activation_function": )" + activation + R"(
71                             },
72                             "custom_options_format": "FLEXBUFFERS"
73                         }
74                     ],
75                 } ],
76                 "buffers" : [
77                     { },
78                     { }
79                 ]
80             }
81         )";
82         Setup();
83     }
84 };
85 
86 struct SimpleDivFixture : public DivFixture
87 {
SimpleDivFixtureSimpleDivFixture88     SimpleDivFixture() : DivFixture("[ 1, 2, 2, 3 ]", "[ 1, 2, 2, 3 ]", "[ 1, 2, 2, 3 ]") {}
89 };
90 
91 TEST_CASE_FIXTURE(SimpleDivFixture, "ParseDiv")
92 {
93     using armnn::DataType;
94     float Inf = std::numeric_limits<float>::infinity();
95     float NaN = std::numeric_limits<float>::quiet_NaN();
96 
97     RunTest<4, DataType::Float32>(0, {{ "inputTensor1", { 0.0f,  1.0f,  2.0f,
98                                                           3.0f,  4.0f,  5.0f,
99                                                           6.0f,  7.0f,  8.0f,
100                                                           9.0f, 10.0f, -11.0f } },
101                                       { "inputTensor2", { 0.0f,  0.0f,  4.0f,
102                                                           3.0f,  40.0f,  5.0f,
103                                                           6.0f,  7.0f,  8.0f,
104                                                           9.0f,  10.0f,  11.0f} } },
105                                      {{ "outputTensor", { NaN,   Inf,  0.5f,
106                                                           1.0f,  0.1f, 1.0f,
107                                                           1.0f,  1.0f, 1.0f,
108                                                           1.0f,  1.0f, -1.0f } } });
109 }
110 
111 
112 struct DynamicDivFixture : public DivFixture
113 {
DynamicDivFixtureDynamicDivFixture114     DynamicDivFixture() : DivFixture("[ 1, 2, 2, 3 ]", "[ 1, 2, 2, 3 ]", "[  ]") {}
115 };
116 
117 TEST_CASE_FIXTURE(DynamicDivFixture, "ParseDynamicDiv")
118 {
119     using armnn::DataType;
120     float Inf = std::numeric_limits<float>::infinity();
121     float NaN = std::numeric_limits<float>::quiet_NaN();
122 
123     RunTest<4, DataType::Float32, DataType::Float32>(0, {{ "inputTensor1", { 0.0f,  1.0f,  2.0f,
124                                                             3.0f,  4.0f,  5.0f,
125                                                             6.0f,  7.0f,  8.0f,
126                                                             9.0f, 10.0f, -11.0f } },
127                                       { "inputTensor2", { 0.0f,  0.0f,  4.0f,
128                                                             3.0f,  40.0f,  5.0f,
129                                                             6.0f,  7.0f,  8.0f,
130                                                             9.0f,  10.0f,  11.0f} } },
131                                   {{ "outputTensor", { NaN,   Inf,  0.5f,
132                                                          1.0f,  0.1f, 1.0f,
133                                                          1.0f,  1.0f, 1.0f,
134                                                          1.0f,  1.0f, -1.0f } } }, true);
135 }
136 
137 }
138