xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/BatchMatMul.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_BatchMatMul")
9*89c4ff92SAndroid Build Coastguard Worker {
10*89c4ff92SAndroid Build Coastguard Worker struct BatchMatMulFixture : public ParserFlatbuffersFixture
11*89c4ff92SAndroid Build Coastguard Worker {
BatchMatMulFixtureBatchMatMulFixture12*89c4ff92SAndroid Build Coastguard Worker     explicit BatchMatMulFixture(const std::string &inputXShape,
13*89c4ff92SAndroid Build Coastguard Worker                                 const std::string &inputYShape,
14*89c4ff92SAndroid Build Coastguard Worker                                 const std::string &outputShape,
15*89c4ff92SAndroid Build Coastguard Worker                                 const std::string &tranX,
16*89c4ff92SAndroid Build Coastguard Worker                                 const std::string &tranY)
17*89c4ff92SAndroid Build Coastguard Worker     {
18*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
19*89c4ff92SAndroid Build Coastguard Worker             {
20*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
21*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "BATCH_MATMUL" } ],
22*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [
23*89c4ff92SAndroid Build Coastguard Worker                     {
24*89c4ff92SAndroid Build Coastguard Worker                         "tensors": [
25*89c4ff92SAndroid Build Coastguard Worker                             {
26*89c4ff92SAndroid Build Coastguard Worker                                 "shape": )" + inputXShape + R"(,
27*89c4ff92SAndroid Build Coastguard Worker                                 "type": "FLOAT32",
28*89c4ff92SAndroid Build Coastguard Worker                                 "buffer": 0,
29*89c4ff92SAndroid Build Coastguard Worker                                 "name": "inputXTensor",
30*89c4ff92SAndroid Build Coastguard Worker                                 "quantization": {
31*89c4ff92SAndroid Build Coastguard Worker                                     "min": [ 0.0 ],
32*89c4ff92SAndroid Build Coastguard Worker                                     "max": [ 255.0 ],
33*89c4ff92SAndroid Build Coastguard Worker                                     "scale": [ 1.0 ],
34*89c4ff92SAndroid Build Coastguard Worker                                     "zero_point": [ 0 ],
35*89c4ff92SAndroid Build Coastguard Worker                                 }
36*89c4ff92SAndroid Build Coastguard Worker                             },
37*89c4ff92SAndroid Build Coastguard Worker                             {
38*89c4ff92SAndroid Build Coastguard Worker                                 "shape": )" + inputYShape + R"(,
39*89c4ff92SAndroid Build Coastguard Worker                                 "type": "FLOAT32",
40*89c4ff92SAndroid Build Coastguard Worker                                 "buffer": 1,
41*89c4ff92SAndroid Build Coastguard Worker                                 "name": "inputYTensor",
42*89c4ff92SAndroid Build Coastguard Worker                                 "quantization": {
43*89c4ff92SAndroid Build Coastguard Worker                                     "min": [ 0.0 ],
44*89c4ff92SAndroid Build Coastguard Worker                                     "max": [ 255.0 ],
45*89c4ff92SAndroid Build Coastguard Worker                                     "scale": [ 1.0 ],
46*89c4ff92SAndroid Build Coastguard Worker                                     "zero_point": [ 0 ],
47*89c4ff92SAndroid Build Coastguard Worker                                 }
48*89c4ff92SAndroid Build Coastguard Worker                             },
49*89c4ff92SAndroid Build Coastguard Worker                             {
50*89c4ff92SAndroid Build Coastguard Worker                                 "shape": )" + outputShape + R"(,
51*89c4ff92SAndroid Build Coastguard Worker                                 "type": "FLOAT32",
52*89c4ff92SAndroid Build Coastguard Worker                                 "buffer": 2,
53*89c4ff92SAndroid Build Coastguard Worker                                 "name": "outputTensor",
54*89c4ff92SAndroid Build Coastguard Worker                                 "quantization": {
55*89c4ff92SAndroid Build Coastguard Worker                                     "min": [ 0.0 ],
56*89c4ff92SAndroid Build Coastguard Worker                                     "max": [ 255.0 ],
57*89c4ff92SAndroid Build Coastguard Worker                                     "scale": [ 1.0 ],
58*89c4ff92SAndroid Build Coastguard Worker                                     "zero_point": [ 0 ],
59*89c4ff92SAndroid Build Coastguard Worker                                 }
60*89c4ff92SAndroid Build Coastguard Worker                             }
61*89c4ff92SAndroid Build Coastguard Worker                         ],
62*89c4ff92SAndroid Build Coastguard Worker                         "inputs": [ 0, 1 ],
63*89c4ff92SAndroid Build Coastguard Worker                         "outputs": [ 2 ],
64*89c4ff92SAndroid Build Coastguard Worker                         "operators": [
65*89c4ff92SAndroid Build Coastguard Worker                             {
66*89c4ff92SAndroid Build Coastguard Worker                                 "opcode_index": 0,
67*89c4ff92SAndroid Build Coastguard Worker                                 "inputs": [ 0 , 1 ],
68*89c4ff92SAndroid Build Coastguard Worker                                 "outputs": [ 2 ],
69*89c4ff92SAndroid Build Coastguard Worker                                 "builtin_options_type": "BatchMatMulOptions",
70*89c4ff92SAndroid Build Coastguard Worker                                 "builtin_options": {
71*89c4ff92SAndroid Build Coastguard Worker                                     adj_x: )" + tranX + R"(,
72*89c4ff92SAndroid Build Coastguard Worker                                     adj_y: )" + tranY + R"(,
73*89c4ff92SAndroid Build Coastguard Worker                                     "asymmetric_quantize_inputs": false
74*89c4ff92SAndroid Build Coastguard Worker                                 },
75*89c4ff92SAndroid Build Coastguard Worker                                 "custom_options_format": "FLEXBUFFERS"
76*89c4ff92SAndroid Build Coastguard Worker                             }
77*89c4ff92SAndroid Build Coastguard Worker                         ]
78*89c4ff92SAndroid Build Coastguard Worker                     }
79*89c4ff92SAndroid Build Coastguard Worker                 ],
80*89c4ff92SAndroid Build Coastguard Worker                 "buffers": [{},{}]
81*89c4ff92SAndroid Build Coastguard Worker             }
82*89c4ff92SAndroid Build Coastguard Worker         )";
83*89c4ff92SAndroid Build Coastguard Worker         Setup();
84*89c4ff92SAndroid Build Coastguard Worker     }
85*89c4ff92SAndroid Build Coastguard Worker };
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker struct BatchMatMulParamsFixture : BatchMatMulFixture
88*89c4ff92SAndroid Build Coastguard Worker {
BatchMatMulParamsFixtureBatchMatMulParamsFixture89*89c4ff92SAndroid Build Coastguard Worker     BatchMatMulParamsFixture()
90*89c4ff92SAndroid Build Coastguard Worker         : BatchMatMulFixture("[ 1, 3, 3 ]",
91*89c4ff92SAndroid Build Coastguard Worker                              "[ 1, 3, 3 ]",
92*89c4ff92SAndroid Build Coastguard Worker                              "[ 1, 3, 3 ]",
93*89c4ff92SAndroid Build Coastguard Worker                              "false",
94*89c4ff92SAndroid Build Coastguard Worker                              "true")
95*89c4ff92SAndroid Build Coastguard Worker     {}
96*89c4ff92SAndroid Build Coastguard Worker };
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(BatchMatMulParamsFixture, "ParseBatchMatMulParams")
99*89c4ff92SAndroid Build Coastguard Worker {
100*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::Float32>(
101*89c4ff92SAndroid Build Coastguard Worker         0,
102*89c4ff92SAndroid Build Coastguard Worker         {{"inputXTensor", {2.0f, 3.0f, 5.0f,
103*89c4ff92SAndroid Build Coastguard Worker                            8.0f, 13.0f, 21.0f,
104*89c4ff92SAndroid Build Coastguard Worker                            34.0f, 55.0f, 89.0f}},
105*89c4ff92SAndroid Build Coastguard Worker          {"inputYTensor", {0.0f, 1.0f, 1.0f,
106*89c4ff92SAndroid Build Coastguard Worker                            1.0f, 0.0f, 1.0f,
107*89c4ff92SAndroid Build Coastguard Worker                            1.0f, 1.0f, 0.0f}}},
108*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor", {8.0f, 7.0f, 5.0f,
109*89c4ff92SAndroid Build Coastguard Worker                            34.0f, 29.0f, 21.0f,
110*89c4ff92SAndroid Build Coastguard Worker                            144.0f, 123.0f, 89.0f}}}
111*89c4ff92SAndroid Build Coastguard Worker         );
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker }