1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersSerializeFixture.hpp" 7 #include <armnnDeserializer/IDeserializer.hpp> 8 9 #include <string> 10 11 TEST_SUITE("Deserializer_Gather") 12 { 13 struct GatherFixture : public ParserFlatbuffersSerializeFixture 14 { GatherFixtureGatherFixture15 explicit GatherFixture(const std::string& inputShape, 16 const std::string& indicesShape, 17 const std::string& input1Content, 18 const std::string& outputShape, 19 const std::string& axis, 20 const std::string dataType, 21 const std::string constDataType) 22 { 23 m_JsonString = R"( 24 { 25 inputIds: [0], 26 outputIds: [3], 27 layers: [ 28 { 29 layer_type: "InputLayer", 30 layer: { 31 base: { 32 layerBindingId: 0, 33 base: { 34 index: 0, 35 layerName: "InputLayer", 36 layerType: "Input", 37 inputSlots: [{ 38 index: 0, 39 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 40 }], 41 outputSlots: [ { 42 index: 0, 43 tensorInfo: { 44 dimensions: )" + inputShape + R"(, 45 dataType: )" + dataType + R"( 46 }}] 47 } 48 }}}, 49 { 50 layer_type: "ConstantLayer", 51 layer: { 52 base: { 53 index:1, 54 layerName: "ConstantLayer", 55 layerType: "Constant", 56 outputSlots: [ { 57 index: 0, 58 tensorInfo: { 59 dimensions: )" + indicesShape + R"(, 60 dataType: "Signed32", 61 }, 62 }], 63 }, 64 input: { 65 info: { 66 dimensions: )" + indicesShape + R"(, 67 dataType: )" + dataType + R"( 68 }, 69 data_type: )" + constDataType + R"(, 70 data: { 71 data: )" + input1Content + R"(, 72 } } 73 },}, 74 { 75 layer_type: "GatherLayer", 76 layer: { 77 base: { 78 index: 2, 79 layerName: "GatherLayer", 80 layerType: "Gather", 81 inputSlots: [ 82 { 83 index: 0, 84 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 85 }, 86 { 87 index: 1, 88 connection: {sourceLayerIndex:1, outputSlotIndex:0 } 89 }], 90 outputSlots: [ { 91 index: 0, 92 tensorInfo: { 93 dimensions: )" + outputShape + R"(, 94 dataType: )" + dataType + R"( 95 96 }}]}, 97 descriptor: { 98 axis: )" + axis + R"( 99 } 100 }}, 101 { 102 layer_type: "OutputLayer", 103 layer: { 104 base:{ 105 layerBindingId: 0, 106 base: { 107 index: 3, 108 layerName: "OutputLayer", 109 layerType: "Output", 110 inputSlots: [{ 111 index: 0, 112 connection: {sourceLayerIndex:2, outputSlotIndex:0 }, 113 }], 114 outputSlots: [ { 115 index: 0, 116 tensorInfo: { 117 dimensions: )" + outputShape + R"(, 118 dataType: )" + dataType + R"( 119 }, 120 }], 121 }}}, 122 }], 123 featureVersions: { 124 weightsLayoutScheme: 1, 125 } 126 } )"; 127 128 Setup(); 129 } 130 }; 131 132 struct SimpleGatherFixtureFloat32 : GatherFixture 133 { SimpleGatherFixtureFloat32SimpleGatherFixtureFloat32134 SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]", 135 "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {} 136 }; 137 138 TEST_CASE_FIXTURE(SimpleGatherFixtureFloat32, "GatherFloat32") 139 { 140 RunTest<4, armnn::DataType::Float32>(0, 141 {{"InputLayer", { 1, 2, 3, 142 4, 5, 6, 143 7, 8, 9, 144 10, 11, 12, 145 13, 14, 15, 146 16, 17, 18 }}}, 147 {{"OutputLayer", { 7, 8, 9, 148 10, 11, 12, 149 13, 14, 15, 150 16, 17, 18, 151 7, 8, 9, 152 10, 11, 12, 153 13, 14, 15, 154 16, 17, 18, 155 7, 8, 9, 156 10, 11, 12, 157 1, 2, 3, 158 4, 5, 6 }}}); 159 } 160 161 } 162 163