1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <mapbox/variant.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer>
MakeInputTensors(const std::vector<armnn::BindingPointInfo> & inputBindings,const std::vector<TContainer> & inputDataContainers)17*89c4ff92SAndroid Build Coastguard Worker inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
18*89c4ff92SAndroid Build Coastguard Worker const std::vector<TContainer>& inputDataContainers)
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors inputTensors;
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker const size_t numInputs = inputBindings.size();
23*89c4ff92SAndroid Build Coastguard Worker if (numInputs != inputDataContainers.size())
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format("The number of inputs does not match number of "
26*89c4ff92SAndroid Build Coastguard Worker "tensor data containers: {0} != {1}",
27*89c4ff92SAndroid Build Coastguard Worker numInputs,
28*89c4ff92SAndroid Build Coastguard Worker inputDataContainers.size()));
29*89c4ff92SAndroid Build Coastguard Worker }
30*89c4ff92SAndroid Build Coastguard Worker
31*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numInputs; i++)
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34*89c4ff92SAndroid Build Coastguard Worker const TContainer& inputData = inputDataContainers[i];
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker mapbox::util::apply_visitor([&](auto&& value)
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker if (value.size() != inputBinding.second.GetNumElements())
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
41*89c4ff92SAndroid Build Coastguard Worker inputBinding.second.GetNumElements(),
42*89c4ff92SAndroid Build Coastguard Worker value.size()));
43*89c4ff92SAndroid Build Coastguard Worker }
44*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = inputBinding.second;
45*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true);
46*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputTensor(inputTensorInfo, value.data());
47*89c4ff92SAndroid Build Coastguard Worker inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
48*89c4ff92SAndroid Build Coastguard Worker },
49*89c4ff92SAndroid Build Coastguard Worker inputData);
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker return inputTensors;
53*89c4ff92SAndroid Build Coastguard Worker }
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer>
MakeOutputTensors(const std::vector<armnn::BindingPointInfo> & outputBindings,std::vector<TContainer> & outputDataContainers)56*89c4ff92SAndroid Build Coastguard Worker inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
57*89c4ff92SAndroid Build Coastguard Worker std::vector<TContainer>& outputDataContainers)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors outputTensors;
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker const size_t numOutputs = outputBindings.size();
62*89c4ff92SAndroid Build Coastguard Worker if (numOutputs != outputDataContainers.size())
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format("Number of outputs does not match number"
65*89c4ff92SAndroid Build Coastguard Worker "of tensor data containers: {0} != {1}",
66*89c4ff92SAndroid Build Coastguard Worker numOutputs,
67*89c4ff92SAndroid Build Coastguard Worker outputDataContainers.size()));
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker
70*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numOutputs; i++)
71*89c4ff92SAndroid Build Coastguard Worker {
72*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& outputBinding = outputBindings[i];
73*89c4ff92SAndroid Build Coastguard Worker TContainer& outputData = outputDataContainers[i];
74*89c4ff92SAndroid Build Coastguard Worker
75*89c4ff92SAndroid Build Coastguard Worker mapbox::util::apply_visitor([&](auto&& value)
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker if (value.size() != outputBinding.second.GetNumElements())
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Output tensor has incorrect size");
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker armnn::Tensor outputTensor(outputBinding.second, value.data());
83*89c4ff92SAndroid Build Coastguard Worker outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
84*89c4ff92SAndroid Build Coastguard Worker },
85*89c4ff92SAndroid Build Coastguard Worker outputData);
86*89c4ff92SAndroid Build Coastguard Worker }
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker return outputTensors;
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
92