xref: /aosp_15_r20/external/armnn/src/armnnUtils/TensorIOUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Tensor.hpp>
9 
10 #include <fmt/format.h>
11 #include <mapbox/variant.hpp>
12 
13 namespace armnnUtils
14 {
15 
16 template<typename TContainer>
MakeInputTensors(const std::vector<armnn::BindingPointInfo> & inputBindings,const std::vector<TContainer> & inputDataContainers)17 inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
18                                             const std::vector<TContainer>& inputDataContainers)
19 {
20     armnn::InputTensors inputTensors;
21 
22     const size_t numInputs = inputBindings.size();
23     if (numInputs != inputDataContainers.size())
24     {
25         throw armnn::Exception(fmt::format("The number of inputs does not match number of "
26                                            "tensor data containers: {0} != {1}",
27                                            numInputs,
28                                            inputDataContainers.size()));
29     }
30 
31     for (size_t i = 0; i < numInputs; i++)
32     {
33         const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34         const TContainer& inputData = inputDataContainers[i];
35 
36         mapbox::util::apply_visitor([&](auto&& value)
37         {
38             if (value.size() != inputBinding.second.GetNumElements())
39             {
40                throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
41                                                   inputBinding.second.GetNumElements(),
42                                                   value.size()));
43             }
44             armnn::TensorInfo inputTensorInfo = inputBinding.second;
45             inputTensorInfo.SetConstant(true);
46             armnn::ConstTensor inputTensor(inputTensorInfo, value.data());
47             inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
48         },
49         inputData);
50     }
51 
52     return inputTensors;
53 }
54 
55 template<typename TContainer>
MakeOutputTensors(const std::vector<armnn::BindingPointInfo> & outputBindings,std::vector<TContainer> & outputDataContainers)56 inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
57                                               std::vector<TContainer>& outputDataContainers)
58 {
59     armnn::OutputTensors outputTensors;
60 
61     const size_t numOutputs = outputBindings.size();
62     if (numOutputs != outputDataContainers.size())
63     {
64         throw armnn::Exception(fmt::format("Number of outputs does not match number"
65                                            "of tensor data containers: {0} != {1}",
66                                            numOutputs,
67                                            outputDataContainers.size()));
68     }
69 
70     for (size_t i = 0; i < numOutputs; i++)
71     {
72         const armnn::BindingPointInfo& outputBinding = outputBindings[i];
73         TContainer& outputData = outputDataContainers[i];
74 
75         mapbox::util::apply_visitor([&](auto&& value)
76         {
77             if (value.size() != outputBinding.second.GetNumElements())
78             {
79                 throw armnn::Exception("Output tensor has incorrect size");
80             }
81 
82             armnn::Tensor outputTensor(outputBinding.second, value.data());
83             outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
84         },
85         outputData);
86     }
87 
88     return outputTensors;
89 }
90 
91 } // namespace armnnUtils
92