1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "CommonTestUtils.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
11*89c4ff92SAndroid Build Coastguard Worker
CreateInputsFrom(Layer * layer,std::vector<unsigned int> ignoreSlots)12*89c4ff92SAndroid Build Coastguard Worker SubgraphView::InputSlots CreateInputsFrom(Layer* layer,
13*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots)
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker SubgraphView::InputSlots result;
16*89c4ff92SAndroid Build Coastguard Worker for (auto&& it = layer->BeginInputSlots(); it != layer->EndInputSlots(); ++it)
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker if (std::find(ignoreSlots.begin(), ignoreSlots.end(), it->GetSlotIndex()) != ignoreSlots.end())
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker continue;
21*89c4ff92SAndroid Build Coastguard Worker }
22*89c4ff92SAndroid Build Coastguard Worker else
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker result.push_back(&(*it));
25*89c4ff92SAndroid Build Coastguard Worker }
26*89c4ff92SAndroid Build Coastguard Worker }
27*89c4ff92SAndroid Build Coastguard Worker return result;
28*89c4ff92SAndroid Build Coastguard Worker }
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker // ignoreSlots assumes you want to ignore the same slots all on layers within the vector
CreateInputsFrom(const std::vector<Layer * > & layers,std::vector<unsigned int> ignoreSlots)31*89c4ff92SAndroid Build Coastguard Worker SubgraphView::InputSlots CreateInputsFrom(const std::vector<Layer*>& layers,
32*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots)
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker SubgraphView::InputSlots result;
35*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer: layers)
36*89c4ff92SAndroid Build Coastguard Worker {
37*89c4ff92SAndroid Build Coastguard Worker for (auto&& it = layer->BeginInputSlots(); it != layer->EndInputSlots(); ++it)
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker if (std::find(ignoreSlots.begin(), ignoreSlots.end(), it->GetSlotIndex()) != ignoreSlots.end())
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker continue;
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker else
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker result.push_back(&(*it));
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker return result;
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
CreateOutputsFrom(const std::vector<Layer * > & layers)52*89c4ff92SAndroid Build Coastguard Worker SubgraphView::OutputSlots CreateOutputsFrom(const std::vector<Layer*>& layers)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker SubgraphView::OutputSlots result;
55*89c4ff92SAndroid Build Coastguard Worker for (auto && layer : layers)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker for (auto&& it = layer->BeginOutputSlots(); it != layer->EndOutputSlots(); ++it)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker result.push_back(&(*it));
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker return result;
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker
CreateSubgraphViewFrom(SubgraphView::InputSlots && inputs,SubgraphView::OutputSlots && outputs,SubgraphView::Layers && layers)65*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr CreateSubgraphViewFrom(SubgraphView::InputSlots&& inputs,
66*89c4ff92SAndroid Build Coastguard Worker SubgraphView::OutputSlots&& outputs,
67*89c4ff92SAndroid Build Coastguard Worker SubgraphView::Layers&& layers)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<SubgraphView>(std::move(inputs), std::move(outputs), std::move(layers));
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker
// Fetches the factory registered for `backendId` from the global backend registry and invokes it,
// returning the backend instance it produces. Behaviour for an unregistered id is whatever
// BackendRegistry::GetFactory does (not visible here).
armnn::IBackendInternalUniquePtr CreateBackendObject(const armnn::BackendId& backendId)
{
    return BackendRegistryInstance().GetFactory(backendId)();
}
80*89c4ff92SAndroid Build Coastguard Worker
MakeTensorShape(unsigned int batches,unsigned int channels,unsigned int height,unsigned int width,armnn::DataLayout layout)81*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape MakeTensorShape(unsigned int batches,
82*89c4ff92SAndroid Build Coastguard Worker unsigned int channels,
83*89c4ff92SAndroid Build Coastguard Worker unsigned int height,
84*89c4ff92SAndroid Build Coastguard Worker unsigned int width,
85*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout layout)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
88*89c4ff92SAndroid Build Coastguard Worker switch (layout)
89*89c4ff92SAndroid Build Coastguard Worker {
90*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW:
91*89c4ff92SAndroid Build Coastguard Worker return TensorShape{ batches, channels, height, width };
92*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC:
93*89c4ff92SAndroid Build Coastguard Worker return TensorShape{ batches, height, width, channels };
94*89c4ff92SAndroid Build Coastguard Worker default:
95*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(std::string("Unsupported data layout: ") + GetDataLayoutName(layout));
96*89c4ff92SAndroid Build Coastguard Worker }
97*89c4ff92SAndroid Build Coastguard Worker }
98