1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 #include "../TfLiteParser.hpp" 8 #include <sstream> 9 10 using armnnTfLiteParser::TfLiteParserImpl; 11 12 TEST_SUITE("TensorflowLiteParser_GetBuffer") 13 { 14 struct GetBufferFixture : public ParserFlatbuffersFixture 15 { GetBufferFixtureGetBufferFixture16 explicit GetBufferFixture() 17 { 18 m_JsonString = R"( 19 { 20 "version": 3, 21 "operator_codes": [ { "builtin_code": "CONV_2D" } ], 22 "subgraphs": [ { 23 "tensors": [ 24 { 25 "shape": [ 1, 3, 3, 1 ], 26 "type": "UINT8", 27 "buffer": 0, 28 "name": "inputTensor", 29 "quantization": { 30 "min": [ 0.0 ], 31 "max": [ 255.0 ], 32 "scale": [ 1.0 ], 33 "zero_point": [ 0 ], 34 } 35 }, 36 { 37 "shape": [ 1, 1, 1, 1 ], 38 "type": "UINT8", 39 "buffer": 1, 40 "name": "outputTensor", 41 "quantization": { 42 "min": [ 0.0 ], 43 "max": [ 511.0 ], 44 "scale": [ 2.0 ], 45 "zero_point": [ 0 ], 46 } 47 }, 48 { 49 "shape": [ 1, 3, 3, 1 ], 50 "type": "UINT8", 51 "buffer": 2, 52 "name": "filterTensor", 53 "quantization": { 54 "min": [ 0.0 ], 55 "max": [ 255.0 ], 56 "scale": [ 1.0 ], 57 "zero_point": [ 0 ], 58 } 59 } 60 ], 61 "inputs": [ 0 ], 62 "outputs": [ 1 ], 63 "operators": [ 64 { 65 "opcode_index": 0, 66 "inputs": [ 0, 2 ], 67 "outputs": [ 1 ], 68 "builtin_options_type": "Conv2DOptions", 69 "builtin_options": { 70 "padding": "VALID", 71 "stride_w": 1, 72 "stride_h": 1, 73 "fused_activation_function": "NONE" 74 }, 75 "custom_options_format": "FLEXBUFFERS" 76 } 77 ], 78 } ], 79 "buffers" : [ 80 { }, 81 { }, 82 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], }, 83 { }, 84 ] 85 } 86 )"; 87 ReadStringToBinary(); 88 } 89 CheckBufferContentsGetBufferFixture90 void CheckBufferContents(const TfLiteParserImpl::ModelPtr& model, 91 std::vector<int32_t> bufferValues, size_t bufferIndex) 92 { 93 for(long unsigned int i=0; i<bufferValues.size(); i++) 94 { 95 CHECK_EQ(TfLiteParserImpl::GetBuffer(model, bufferIndex)->data[i], bufferValues[i]); 96 } 97 } 98 }; 99 100 TEST_CASE_FIXTURE(GetBufferFixture, "GetBufferCheckContents") 101 { 102 //Check contents of buffer are correct 103 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 104 m_GraphBinary.size()); 105 std::vector<int32_t> bufferValues = {2,1,0,6,2,1,4,1,2}; 106 CheckBufferContents(model, bufferValues, 2); 107 } 108 109 TEST_CASE_FIXTURE(GetBufferFixture, "GetBufferCheckEmpty") 110 { 111 //Check if test fixture buffers are empty or not 112 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 113 m_GraphBinary.size()); 114 CHECK(TfLiteParserImpl::GetBuffer(model, 0)->data.empty()); 115 CHECK(TfLiteParserImpl::GetBuffer(model, 1)->data.empty()); 116 CHECK(!TfLiteParserImpl::GetBuffer(model, 2)->data.empty()); 117 CHECK(TfLiteParserImpl::GetBuffer(model, 3)->data.empty()); 118 } 119 120 TEST_CASE_FIXTURE(GetBufferFixture, "GetBufferCheckParseException") 121 { 122 //Check if armnn::ParseException thrown when invalid buffer index used 123 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 124 m_GraphBinary.size()); 125 CHECK_THROWS_AS(TfLiteParserImpl::GetBuffer(model, 4), armnn::Exception); 126 } 127 128 } 129