1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <armnn/Tensor.hpp>
9
10 #include <fmt/format.h>
11 #include <mapbox/variant.hpp>
12
13 namespace armnnUtils
14 {
15
16 template<typename TContainer>
MakeInputTensors(const std::vector<armnn::BindingPointInfo> & inputBindings,const std::vector<TContainer> & inputDataContainers)17 inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
18 const std::vector<TContainer>& inputDataContainers)
19 {
20 armnn::InputTensors inputTensors;
21
22 const size_t numInputs = inputBindings.size();
23 if (numInputs != inputDataContainers.size())
24 {
25 throw armnn::Exception(fmt::format("The number of inputs does not match number of "
26 "tensor data containers: {0} != {1}",
27 numInputs,
28 inputDataContainers.size()));
29 }
30
31 for (size_t i = 0; i < numInputs; i++)
32 {
33 const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34 const TContainer& inputData = inputDataContainers[i];
35
36 mapbox::util::apply_visitor([&](auto&& value)
37 {
38 if (value.size() != inputBinding.second.GetNumElements())
39 {
40 throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
41 inputBinding.second.GetNumElements(),
42 value.size()));
43 }
44 armnn::TensorInfo inputTensorInfo = inputBinding.second;
45 inputTensorInfo.SetConstant(true);
46 armnn::ConstTensor inputTensor(inputTensorInfo, value.data());
47 inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
48 },
49 inputData);
50 }
51
52 return inputTensors;
53 }
54
55 template<typename TContainer>
MakeOutputTensors(const std::vector<armnn::BindingPointInfo> & outputBindings,std::vector<TContainer> & outputDataContainers)56 inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
57 std::vector<TContainer>& outputDataContainers)
58 {
59 armnn::OutputTensors outputTensors;
60
61 const size_t numOutputs = outputBindings.size();
62 if (numOutputs != outputDataContainers.size())
63 {
64 throw armnn::Exception(fmt::format("Number of outputs does not match number"
65 "of tensor data containers: {0} != {1}",
66 numOutputs,
67 outputDataContainers.size()));
68 }
69
70 for (size_t i = 0; i < numOutputs; i++)
71 {
72 const armnn::BindingPointInfo& outputBinding = outputBindings[i];
73 TContainer& outputData = outputDataContainers[i];
74
75 mapbox::util::apply_visitor([&](auto&& value)
76 {
77 if (value.size() != outputBinding.second.GetNumElements())
78 {
79 throw armnn::Exception("Output tensor has incorrect size");
80 }
81
82 armnn::Tensor outputTensor(outputBinding.second, value.data());
83 outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
84 },
85 outputData);
86 }
87
88 return outputTensors;
89 }
90
91 } // namespace armnnUtils
92