1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 TEST_SUITE("TensorflowLiteParser_BatchMatMul") 9 { 10 struct BatchMatMulFixture : public ParserFlatbuffersFixture 11 { BatchMatMulFixtureBatchMatMulFixture12 explicit BatchMatMulFixture(const std::string &inputXShape, 13 const std::string &inputYShape, 14 const std::string &outputShape, 15 const std::string &tranX, 16 const std::string &tranY) 17 { 18 m_JsonString = R"( 19 { 20 "version": 3, 21 "operator_codes": [ { "builtin_code": "BATCH_MATMUL" } ], 22 "subgraphs": [ 23 { 24 "tensors": [ 25 { 26 "shape": )" + inputXShape + R"(, 27 "type": "FLOAT32", 28 "buffer": 0, 29 "name": "inputXTensor", 30 "quantization": { 31 "min": [ 0.0 ], 32 "max": [ 255.0 ], 33 "scale": [ 1.0 ], 34 "zero_point": [ 0 ], 35 } 36 }, 37 { 38 "shape": )" + inputYShape + R"(, 39 "type": "FLOAT32", 40 "buffer": 1, 41 "name": "inputYTensor", 42 "quantization": { 43 "min": [ 0.0 ], 44 "max": [ 255.0 ], 45 "scale": [ 1.0 ], 46 "zero_point": [ 0 ], 47 } 48 }, 49 { 50 "shape": )" + outputShape + R"(, 51 "type": "FLOAT32", 52 "buffer": 2, 53 "name": "outputTensor", 54 "quantization": { 55 "min": [ 0.0 ], 56 "max": [ 255.0 ], 57 "scale": [ 1.0 ], 58 "zero_point": [ 0 ], 59 } 60 } 61 ], 62 "inputs": [ 0, 1 ], 63 "outputs": [ 2 ], 64 "operators": [ 65 { 66 "opcode_index": 0, 67 "inputs": [ 0 , 1 ], 68 "outputs": [ 2 ], 69 "builtin_options_type": "BatchMatMulOptions", 70 "builtin_options": { 71 adj_x: )" + tranX + R"(, 72 adj_y: )" + tranY + R"(, 73 "asymmetric_quantize_inputs": false 74 }, 75 "custom_options_format": "FLEXBUFFERS" 76 } 77 ] 78 } 79 ], 80 "buffers": [{},{}] 81 } 82 )"; 83 Setup(); 84 } 85 }; 86 87 struct BatchMatMulParamsFixture : BatchMatMulFixture 88 { BatchMatMulParamsFixtureBatchMatMulParamsFixture89 BatchMatMulParamsFixture() 90 : BatchMatMulFixture("[ 1, 3, 3 ]", 91 "[ 1, 3, 3 ]", 92 "[ 1, 3, 3 ]", 93 "false", 94 "true") 95 {} 96 }; 97 98 TEST_CASE_FIXTURE(BatchMatMulParamsFixture, "ParseBatchMatMulParams") 99 { 100 RunTest<3, armnn::DataType::Float32>( 101 0, 102 {{"inputXTensor", {2.0f, 3.0f, 5.0f, 103 8.0f, 13.0f, 21.0f, 104 34.0f, 55.0f, 89.0f}}, 105 {"inputYTensor", {0.0f, 1.0f, 1.0f, 106 1.0f, 0.0f, 1.0f, 107 1.0f, 1.0f, 0.0f}}}, 108 {{"outputTensor", {8.0f, 7.0f, 5.0f, 109 34.0f, 29.0f, 21.0f, 110 144.0f, 123.0f, 89.0f}}} 111 ); 112 } 113 114 }