1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "../OnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 #include <onnx/onnx.pb.h> 9 #include "google/protobuf/stubs/logging.h" 10 11 using ModelPtr = std::unique_ptr<onnx::ModelProto>; 12 13 TEST_SUITE("OnnxParser_GetInputsOutputs") 14 { 15 struct GetInputsOutputsMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 16 { GetInputsOutputsMainFixtureGetInputsOutputsMainFixture17 explicit GetInputsOutputsMainFixture() 18 { 19 m_Prototext = R"( 20 ir_version: 3 21 producer_name: "CNTK" 22 producer_version: "2.5.1" 23 domain: "ai.cntk" 24 model_version: 1 25 graph { 26 name: "CNTKGraph" 27 input { 28 name: "Input" 29 type { 30 tensor_type { 31 elem_type: 1 32 shape { 33 dim { 34 dim_value: 4 35 } 36 } 37 } 38 } 39 } 40 node { 41 input: "Input" 42 output: "Output" 43 name: "ActivationLayer" 44 op_type: "Relu" 45 } 46 output { 47 name: "Output" 48 type { 49 tensor_type { 50 elem_type: 1 51 shape { 52 dim { 53 dim_value: 4 54 } 55 } 56 } 57 } 58 } 59 } 60 opset_import { 61 version: 7 62 })"; 63 Setup(); 64 } 65 }; 66 67 68 TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetInput") 69 { 70 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); 71 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model); 72 CHECK_EQ(1, tensors.size()); 73 CHECK_EQ("Input", tensors[0]); 74 75 } 76 77 TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetOutput") 78 { 79 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); 80 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetOutputs(model); 81 CHECK_EQ(1, tensors.size()); 82 CHECK_EQ("Output", tensors[0]); 83 } 84 85 struct GetEmptyInputsOutputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 86 { GetEmptyInputsOutputsFixtureGetEmptyInputsOutputsFixture87 GetEmptyInputsOutputsFixture() 88 { 89 m_Prototext = R"( 90 ir_version: 3 91 producer_name: "CNTK " 92 producer_version: "2.5.1 " 93 domain: "ai.cntk " 94 model_version: 1 95 graph { 96 name: "CNTKGraph " 97 node { 98 output: "Output" 99 attribute { 100 name: "value" 101 t { 102 dims: 7 103 data_type: 1 104 float_data: 0.0 105 float_data: 1.0 106 float_data: 2.0 107 float_data: 3.0 108 float_data: 4.0 109 float_data: 5.0 110 float_data: 6.0 111 112 } 113 type: 1 114 } 115 name: "constantNode" 116 op_type: "Constant" 117 } 118 output { 119 name: "Output" 120 type { 121 tensor_type { 122 elem_type: 1 123 shape { 124 dim { 125 dim_value: 7 126 } 127 } 128 } 129 } 130 } 131 } 132 opset_import { 133 version: 7 134 })"; 135 Setup(); 136 } 137 }; 138 139 TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyInputs") 140 { 141 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); 142 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model); 143 CHECK_EQ(0, tensors.size()); 144 } 145 146 TEST_CASE("GetInputsNullModel") 147 { 148 CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString(""), armnn::InvalidArgumentException); 149 } 150 151 TEST_CASE("GetOutputsNullModel") 152 { 153 auto silencer = google::protobuf::LogSilencer(); //get rid of errors from protobuf 154 CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString("nknnk"), armnn::ParseException); 155 } 156 157 struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 158 { GetInputsMultipleFixtureGetInputsMultipleFixture159 GetInputsMultipleFixture() { 160 161 m_Prototext = R"( 162 ir_version: 3 163 producer_name: "CNTK" 164 producer_version: "2.5.1" 165 domain: "ai.cntk" 166 model_version: 1 167 graph { 168 name: "CNTKGraph" 169 input { 170 name: "Input0" 171 type { 172 tensor_type { 173 elem_type: 1 174 shape { 175 dim { 176 dim_value: 1 177 } 178 dim { 179 dim_value: 1 180 } 181 dim { 182 dim_value: 1 183 } 184 dim { 185 dim_value: 4 186 } 187 } 188 } 189 } 190 } 191 input { 192 name: "Input1" 193 type { 194 tensor_type { 195 elem_type: 1 196 shape { 197 dim { 198 dim_value: 4 199 } 200 } 201 } 202 } 203 } 204 node { 205 input: "Input0" 206 input: "Input1" 207 output: "Output" 208 name: "addition" 209 op_type: "Add" 210 doc_string: "" 211 domain: "" 212 } 213 output { 214 name: "Output" 215 type { 216 tensor_type { 217 elem_type: 1 218 shape { 219 dim { 220 dim_value: 1 221 } 222 dim { 223 dim_value: 1 224 } 225 dim { 226 dim_value: 1 227 } 228 dim { 229 dim_value: 4 230 } 231 } 232 } 233 } 234 } 235 } 236 opset_import { 237 version: 7 238 })"; 239 Setup(); 240 } 241 }; 242 243 TEST_CASE_FIXTURE(GetInputsMultipleFixture, "GetInputsMultipleInputs") 244 { 245 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); 246 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model); 247 CHECK_EQ(2, tensors.size()); 248 CHECK_EQ("Input0", tensors[0]); 249 CHECK_EQ("Input1", tensors[1]); 250 } 251 252 } 253