xref: /aosp_15_r20/external/armnn/src/armnnTestUtils/CommonTestUtils.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "CommonTestUtils.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
11*89c4ff92SAndroid Build Coastguard Worker 
CreateInputsFrom(Layer * layer,std::vector<unsigned int> ignoreSlots)12*89c4ff92SAndroid Build Coastguard Worker SubgraphView::InputSlots CreateInputsFrom(Layer* layer,
13*89c4ff92SAndroid Build Coastguard Worker                                           std::vector<unsigned int> ignoreSlots)
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::InputSlots result;
16*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it = layer->BeginInputSlots(); it != layer->EndInputSlots(); ++it)
17*89c4ff92SAndroid Build Coastguard Worker     {
18*89c4ff92SAndroid Build Coastguard Worker         if (std::find(ignoreSlots.begin(), ignoreSlots.end(), it->GetSlotIndex()) != ignoreSlots.end())
19*89c4ff92SAndroid Build Coastguard Worker         {
20*89c4ff92SAndroid Build Coastguard Worker             continue;
21*89c4ff92SAndroid Build Coastguard Worker         }
22*89c4ff92SAndroid Build Coastguard Worker         else
23*89c4ff92SAndroid Build Coastguard Worker         {
24*89c4ff92SAndroid Build Coastguard Worker             result.push_back(&(*it));
25*89c4ff92SAndroid Build Coastguard Worker         }
26*89c4ff92SAndroid Build Coastguard Worker     }
27*89c4ff92SAndroid Build Coastguard Worker         return result;
28*89c4ff92SAndroid Build Coastguard Worker }
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker // ignoreSlots assumes you want to ignore the same slots all on layers within the vector
CreateInputsFrom(const std::vector<Layer * > & layers,std::vector<unsigned int> ignoreSlots)31*89c4ff92SAndroid Build Coastguard Worker SubgraphView::InputSlots CreateInputsFrom(const std::vector<Layer*>& layers,
32*89c4ff92SAndroid Build Coastguard Worker                                           std::vector<unsigned int> ignoreSlots)
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::InputSlots result;
35*89c4ff92SAndroid Build Coastguard Worker     for (auto&& layer: layers)
36*89c4ff92SAndroid Build Coastguard Worker     {
37*89c4ff92SAndroid Build Coastguard Worker         for (auto&& it = layer->BeginInputSlots(); it != layer->EndInputSlots(); ++it)
38*89c4ff92SAndroid Build Coastguard Worker         {
39*89c4ff92SAndroid Build Coastguard Worker             if (std::find(ignoreSlots.begin(), ignoreSlots.end(), it->GetSlotIndex()) != ignoreSlots.end())
40*89c4ff92SAndroid Build Coastguard Worker             {
41*89c4ff92SAndroid Build Coastguard Worker                 continue;
42*89c4ff92SAndroid Build Coastguard Worker             }
43*89c4ff92SAndroid Build Coastguard Worker             else
44*89c4ff92SAndroid Build Coastguard Worker             {
45*89c4ff92SAndroid Build Coastguard Worker                 result.push_back(&(*it));
46*89c4ff92SAndroid Build Coastguard Worker             }
47*89c4ff92SAndroid Build Coastguard Worker         }
48*89c4ff92SAndroid Build Coastguard Worker     }
49*89c4ff92SAndroid Build Coastguard Worker     return result;
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker 
CreateOutputsFrom(const std::vector<Layer * > & layers)52*89c4ff92SAndroid Build Coastguard Worker SubgraphView::OutputSlots CreateOutputsFrom(const std::vector<Layer*>& layers)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::OutputSlots result;
55*89c4ff92SAndroid Build Coastguard Worker     for (auto && layer : layers)
56*89c4ff92SAndroid Build Coastguard Worker     {
57*89c4ff92SAndroid Build Coastguard Worker         for (auto&& it = layer->BeginOutputSlots(); it != layer->EndOutputSlots(); ++it)
58*89c4ff92SAndroid Build Coastguard Worker         {
59*89c4ff92SAndroid Build Coastguard Worker             result.push_back(&(*it));
60*89c4ff92SAndroid Build Coastguard Worker         }
61*89c4ff92SAndroid Build Coastguard Worker     }
62*89c4ff92SAndroid Build Coastguard Worker     return result;
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker 
CreateSubgraphViewFrom(SubgraphView::InputSlots && inputs,SubgraphView::OutputSlots && outputs,SubgraphView::Layers && layers)65*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr CreateSubgraphViewFrom(SubgraphView::InputSlots&& inputs,
66*89c4ff92SAndroid Build Coastguard Worker                                                      SubgraphView::OutputSlots&& outputs,
67*89c4ff92SAndroid Build Coastguard Worker                                                      SubgraphView::Layers&& layers)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<SubgraphView>(std::move(inputs), std::move(outputs), std::move(layers));
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker 
// Fetches the factory registered under `backendId` from the global backend
// registry (BackendRegistryInstance()) and returns the backend object that
// the factory creates.
armnn::IBackendInternalUniquePtr CreateBackendObject(const armnn::BackendId& backendId)
{
    const auto factory = BackendRegistryInstance().GetFactory(backendId);
    return factory();
}
80*89c4ff92SAndroid Build Coastguard Worker 
MakeTensorShape(unsigned int batches,unsigned int channels,unsigned int height,unsigned int width,armnn::DataLayout layout)81*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape MakeTensorShape(unsigned int batches,
82*89c4ff92SAndroid Build Coastguard Worker                                    unsigned int channels,
83*89c4ff92SAndroid Build Coastguard Worker                                    unsigned int height,
84*89c4ff92SAndroid Build Coastguard Worker                                    unsigned int width,
85*89c4ff92SAndroid Build Coastguard Worker                                    armnn::DataLayout layout)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
88*89c4ff92SAndroid Build Coastguard Worker     switch (layout)
89*89c4ff92SAndroid Build Coastguard Worker     {
90*89c4ff92SAndroid Build Coastguard Worker         case DataLayout::NCHW:
91*89c4ff92SAndroid Build Coastguard Worker             return TensorShape{ batches, channels, height, width };
92*89c4ff92SAndroid Build Coastguard Worker         case DataLayout::NHWC:
93*89c4ff92SAndroid Build Coastguard Worker             return TensorShape{ batches, height, width, channels };
94*89c4ff92SAndroid Build Coastguard Worker         default:
95*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException(std::string("Unsupported data layout: ") + GetDataLayoutName(layout));
96*89c4ff92SAndroid Build Coastguard Worker     }
97*89c4ff92SAndroid Build Coastguard Worker }
98