xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/BatchMatMul.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 TEST_SUITE("TensorflowLiteParser_BatchMatMul")
9 {
10 struct BatchMatMulFixture : public ParserFlatbuffersFixture
11 {
BatchMatMulFixtureBatchMatMulFixture12     explicit BatchMatMulFixture(const std::string &inputXShape,
13                                 const std::string &inputYShape,
14                                 const std::string &outputShape,
15                                 const std::string &tranX,
16                                 const std::string &tranY)
17     {
18         m_JsonString = R"(
19             {
20                 "version": 3,
21                 "operator_codes": [ { "builtin_code": "BATCH_MATMUL" } ],
22                 "subgraphs": [
23                     {
24                         "tensors": [
25                             {
26                                 "shape": )" + inputXShape + R"(,
27                                 "type": "FLOAT32",
28                                 "buffer": 0,
29                                 "name": "inputXTensor",
30                                 "quantization": {
31                                     "min": [ 0.0 ],
32                                     "max": [ 255.0 ],
33                                     "scale": [ 1.0 ],
34                                     "zero_point": [ 0 ],
35                                 }
36                             },
37                             {
38                                 "shape": )" + inputYShape + R"(,
39                                 "type": "FLOAT32",
40                                 "buffer": 1,
41                                 "name": "inputYTensor",
42                                 "quantization": {
43                                     "min": [ 0.0 ],
44                                     "max": [ 255.0 ],
45                                     "scale": [ 1.0 ],
46                                     "zero_point": [ 0 ],
47                                 }
48                             },
49                             {
50                                 "shape": )" + outputShape + R"(,
51                                 "type": "FLOAT32",
52                                 "buffer": 2,
53                                 "name": "outputTensor",
54                                 "quantization": {
55                                     "min": [ 0.0 ],
56                                     "max": [ 255.0 ],
57                                     "scale": [ 1.0 ],
58                                     "zero_point": [ 0 ],
59                                 }
60                             }
61                         ],
62                         "inputs": [ 0, 1 ],
63                         "outputs": [ 2 ],
64                         "operators": [
65                             {
66                                 "opcode_index": 0,
67                                 "inputs": [ 0 , 1 ],
68                                 "outputs": [ 2 ],
69                                 "builtin_options_type": "BatchMatMulOptions",
70                                 "builtin_options": {
71                                     adj_x: )" + tranX + R"(,
72                                     adj_y: )" + tranY + R"(,
73                                     "asymmetric_quantize_inputs": false
74                                 },
75                                 "custom_options_format": "FLEXBUFFERS"
76                             }
77                         ]
78                     }
79                 ],
80                 "buffers": [{},{}]
81             }
82         )";
83         Setup();
84     }
85 };
86 
87 struct BatchMatMulParamsFixture : BatchMatMulFixture
88 {
BatchMatMulParamsFixtureBatchMatMulParamsFixture89     BatchMatMulParamsFixture()
90         : BatchMatMulFixture("[ 1, 3, 3 ]",
91                              "[ 1, 3, 3 ]",
92                              "[ 1, 3, 3 ]",
93                              "false",
94                              "true")
95     {}
96 };
97 
98 TEST_CASE_FIXTURE(BatchMatMulParamsFixture, "ParseBatchMatMulParams")
99 {
100     RunTest<3, armnn::DataType::Float32>(
101         0,
102         {{"inputXTensor", {2.0f, 3.0f, 5.0f,
103                            8.0f, 13.0f, 21.0f,
104                            34.0f, 55.0f, 89.0f}},
105          {"inputYTensor", {0.0f, 1.0f, 1.0f,
106                            1.0f, 0.0f, 1.0f,
107                            1.0f, 1.0f, 0.0f}}},
108         {{"outputTensor", {8.0f, 7.0f, 5.0f,
109                            34.0f, 29.0f, 21.0f,
110                            144.0f, 123.0f, 89.0f}}}
111         );
112 }
113 
114 }