xref: /aosp_15_r20/external/armnn/src/armnnUtils/TensorIOUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <mapbox/variant.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer>
MakeInputTensors(const std::vector<armnn::BindingPointInfo> & inputBindings,const std::vector<TContainer> & inputDataContainers)17*89c4ff92SAndroid Build Coastguard Worker inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
18*89c4ff92SAndroid Build Coastguard Worker                                             const std::vector<TContainer>& inputDataContainers)
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors inputTensors;
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker     const size_t numInputs = inputBindings.size();
23*89c4ff92SAndroid Build Coastguard Worker     if (numInputs != inputDataContainers.size())
24*89c4ff92SAndroid Build Coastguard Worker     {
25*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception(fmt::format("The number of inputs does not match number of "
26*89c4ff92SAndroid Build Coastguard Worker                                            "tensor data containers: {0} != {1}",
27*89c4ff92SAndroid Build Coastguard Worker                                            numInputs,
28*89c4ff92SAndroid Build Coastguard Worker                                            inputDataContainers.size()));
29*89c4ff92SAndroid Build Coastguard Worker     }
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker     for (size_t i = 0; i < numInputs; i++)
32*89c4ff92SAndroid Build Coastguard Worker     {
33*89c4ff92SAndroid Build Coastguard Worker         const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34*89c4ff92SAndroid Build Coastguard Worker         const TContainer& inputData = inputDataContainers[i];
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker         mapbox::util::apply_visitor([&](auto&& value)
37*89c4ff92SAndroid Build Coastguard Worker         {
38*89c4ff92SAndroid Build Coastguard Worker             if (value.size() != inputBinding.second.GetNumElements())
39*89c4ff92SAndroid Build Coastguard Worker             {
40*89c4ff92SAndroid Build Coastguard Worker                throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
41*89c4ff92SAndroid Build Coastguard Worker                                                   inputBinding.second.GetNumElements(),
42*89c4ff92SAndroid Build Coastguard Worker                                                   value.size()));
43*89c4ff92SAndroid Build Coastguard Worker             }
44*89c4ff92SAndroid Build Coastguard Worker             armnn::TensorInfo inputTensorInfo = inputBinding.second;
45*89c4ff92SAndroid Build Coastguard Worker             inputTensorInfo.SetConstant(true);
46*89c4ff92SAndroid Build Coastguard Worker             armnn::ConstTensor inputTensor(inputTensorInfo, value.data());
47*89c4ff92SAndroid Build Coastguard Worker             inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
48*89c4ff92SAndroid Build Coastguard Worker         },
49*89c4ff92SAndroid Build Coastguard Worker         inputData);
50*89c4ff92SAndroid Build Coastguard Worker     }
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     return inputTensors;
53*89c4ff92SAndroid Build Coastguard Worker }
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer>
MakeOutputTensors(const std::vector<armnn::BindingPointInfo> & outputBindings,std::vector<TContainer> & outputDataContainers)56*89c4ff92SAndroid Build Coastguard Worker inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
57*89c4ff92SAndroid Build Coastguard Worker                                               std::vector<TContainer>& outputDataContainers)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors outputTensors;
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     const size_t numOutputs = outputBindings.size();
62*89c4ff92SAndroid Build Coastguard Worker     if (numOutputs != outputDataContainers.size())
63*89c4ff92SAndroid Build Coastguard Worker     {
64*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception(fmt::format("Number of outputs does not match number"
65*89c4ff92SAndroid Build Coastguard Worker                                            "of tensor data containers: {0} != {1}",
66*89c4ff92SAndroid Build Coastguard Worker                                            numOutputs,
67*89c4ff92SAndroid Build Coastguard Worker                                            outputDataContainers.size()));
68*89c4ff92SAndroid Build Coastguard Worker     }
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker     for (size_t i = 0; i < numOutputs; i++)
71*89c4ff92SAndroid Build Coastguard Worker     {
72*89c4ff92SAndroid Build Coastguard Worker         const armnn::BindingPointInfo& outputBinding = outputBindings[i];
73*89c4ff92SAndroid Build Coastguard Worker         TContainer& outputData = outputDataContainers[i];
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker         mapbox::util::apply_visitor([&](auto&& value)
76*89c4ff92SAndroid Build Coastguard Worker         {
77*89c4ff92SAndroid Build Coastguard Worker             if (value.size() != outputBinding.second.GetNumElements())
78*89c4ff92SAndroid Build Coastguard Worker             {
79*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception("Output tensor has incorrect size");
80*89c4ff92SAndroid Build Coastguard Worker             }
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker             armnn::Tensor outputTensor(outputBinding.second, value.data());
83*89c4ff92SAndroid Build Coastguard Worker             outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
84*89c4ff92SAndroid Build Coastguard Worker         },
85*89c4ff92SAndroid Build Coastguard Worker         outputData);
86*89c4ff92SAndroid Build Coastguard Worker     }
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker     return outputTensors;
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
92