xref: /aosp_15_r20/external/armnn/delegate/classic/src/MultiLayerFacade.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021,2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
// NOTE: the MultiLayerFacade class is a utility class which makes a chain
//       of operators look like a single IConnectableLayer, with the first
//       layer in the chain supplying the input slots and the last supplying
//       the output slots. It enables us, for example, to simulate a
//       TensorFlow Lite FloorDiv operator by chaining a Div layer followed
//       by a Floor layer and passing them as a single unit to the code that
//       connects up the graph as the delegate proceeds to build up the
//       Arm NN subgraphs.
//
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker #include <common/include/ProfilingGuid.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker class MultiLayerFacade : public armnn::IConnectableLayer
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker public:
MultiLayerFacade()27*89c4ff92SAndroid Build Coastguard Worker     MultiLayerFacade() :
28*89c4ff92SAndroid Build Coastguard Worker         m_FirstLayer(nullptr), m_LastLayer(nullptr) {}
29*89c4ff92SAndroid Build Coastguard Worker 
MultiLayerFacade(armnn::IConnectableLayer * firstLayer,armnn::IConnectableLayer * lastLayer)30*89c4ff92SAndroid Build Coastguard Worker     MultiLayerFacade(armnn::IConnectableLayer* firstLayer, armnn::IConnectableLayer* lastLayer) :
31*89c4ff92SAndroid Build Coastguard Worker         m_FirstLayer(firstLayer), m_LastLayer(lastLayer) {}
32*89c4ff92SAndroid Build Coastguard Worker 
MultiLayerFacade(const MultiLayerFacade & obj)33*89c4ff92SAndroid Build Coastguard Worker     MultiLayerFacade(const MultiLayerFacade& obj) :
34*89c4ff92SAndroid Build Coastguard Worker         m_FirstLayer(obj.m_FirstLayer), m_LastLayer(obj.m_LastLayer) {}
35*89c4ff92SAndroid Build Coastguard Worker 
~MultiLayerFacade()36*89c4ff92SAndroid Build Coastguard Worker     ~MultiLayerFacade() {} // we don't own the pointers
37*89c4ff92SAndroid Build Coastguard Worker 
operator =(const MultiLayerFacade & obj)38*89c4ff92SAndroid Build Coastguard Worker     MultiLayerFacade& operator=(const MultiLayerFacade& obj)
39*89c4ff92SAndroid Build Coastguard Worker     {
40*89c4ff92SAndroid Build Coastguard Worker         m_FirstLayer = obj.m_FirstLayer;
41*89c4ff92SAndroid Build Coastguard Worker         m_LastLayer = obj.m_LastLayer;
42*89c4ff92SAndroid Build Coastguard Worker         return *this;
43*89c4ff92SAndroid Build Coastguard Worker     }
44*89c4ff92SAndroid Build Coastguard Worker 
AssignValues(armnn::IConnectableLayer * firstLayer,armnn::IConnectableLayer * lastLayer)45*89c4ff92SAndroid Build Coastguard Worker     void AssignValues(armnn::IConnectableLayer* firstLayer, armnn::IConnectableLayer* lastLayer)
46*89c4ff92SAndroid Build Coastguard Worker     {
47*89c4ff92SAndroid Build Coastguard Worker         m_FirstLayer = firstLayer;
48*89c4ff92SAndroid Build Coastguard Worker         m_LastLayer = lastLayer;
49*89c4ff92SAndroid Build Coastguard Worker     }
50*89c4ff92SAndroid Build Coastguard Worker 
GetName() const51*89c4ff92SAndroid Build Coastguard Worker     virtual const char* GetName() const override
52*89c4ff92SAndroid Build Coastguard Worker     {
53*89c4ff92SAndroid Build Coastguard Worker         return m_FirstLayer->GetName();
54*89c4ff92SAndroid Build Coastguard Worker     }
55*89c4ff92SAndroid Build Coastguard Worker 
GetNumInputSlots() const56*89c4ff92SAndroid Build Coastguard Worker     virtual unsigned int GetNumInputSlots() const override
57*89c4ff92SAndroid Build Coastguard Worker     {
58*89c4ff92SAndroid Build Coastguard Worker         return m_FirstLayer->GetNumInputSlots();
59*89c4ff92SAndroid Build Coastguard Worker     }
60*89c4ff92SAndroid Build Coastguard Worker 
GetNumOutputSlots() const61*89c4ff92SAndroid Build Coastguard Worker     virtual unsigned int GetNumOutputSlots() const override
62*89c4ff92SAndroid Build Coastguard Worker     {
63*89c4ff92SAndroid Build Coastguard Worker         return m_LastLayer->GetNumOutputSlots();
64*89c4ff92SAndroid Build Coastguard Worker     }
65*89c4ff92SAndroid Build Coastguard Worker 
GetInputSlot(unsigned int index) const66*89c4ff92SAndroid Build Coastguard Worker     virtual const armnn::IInputSlot& GetInputSlot(unsigned int index) const override
67*89c4ff92SAndroid Build Coastguard Worker     {
68*89c4ff92SAndroid Build Coastguard Worker         return m_FirstLayer->GetInputSlot(index);
69*89c4ff92SAndroid Build Coastguard Worker     }
70*89c4ff92SAndroid Build Coastguard Worker 
GetInputSlot(unsigned int index)71*89c4ff92SAndroid Build Coastguard Worker     virtual armnn::IInputSlot& GetInputSlot(unsigned int index) override
72*89c4ff92SAndroid Build Coastguard Worker     {
73*89c4ff92SAndroid Build Coastguard Worker         return m_FirstLayer->GetInputSlot(index);
74*89c4ff92SAndroid Build Coastguard Worker     }
75*89c4ff92SAndroid Build Coastguard Worker 
GetOutputSlot(unsigned int index) const76*89c4ff92SAndroid Build Coastguard Worker     virtual const armnn::IOutputSlot& GetOutputSlot(unsigned int index) const override
77*89c4ff92SAndroid Build Coastguard Worker     {
78*89c4ff92SAndroid Build Coastguard Worker         return m_LastLayer->GetOutputSlot(index);
79*89c4ff92SAndroid Build Coastguard Worker     }
80*89c4ff92SAndroid Build Coastguard Worker 
GetOutputSlot(unsigned int index)81*89c4ff92SAndroid Build Coastguard Worker     virtual armnn::IOutputSlot& GetOutputSlot(unsigned int index) override
82*89c4ff92SAndroid Build Coastguard Worker     {
83*89c4ff92SAndroid Build Coastguard Worker         return m_LastLayer->GetOutputSlot(index);
84*89c4ff92SAndroid Build Coastguard Worker     }
85*89c4ff92SAndroid Build Coastguard Worker 
InferOutputShapes(const std::vector<armnn::TensorShape> & inputShapes) const86*89c4ff92SAndroid Build Coastguard Worker     virtual std::vector<armnn::TensorShape> InferOutputShapes(
87*89c4ff92SAndroid Build Coastguard Worker         const std::vector<armnn::TensorShape>& inputShapes) const override
88*89c4ff92SAndroid Build Coastguard Worker     {
89*89c4ff92SAndroid Build Coastguard Worker         // NOTE: do not expect this function to be used. Likely that if it is it might need to be overridden
90*89c4ff92SAndroid Build Coastguard Worker         //       for particular sequences of operators.
91*89c4ff92SAndroid Build Coastguard Worker         return m_FirstLayer->InferOutputShapes(inputShapes);
92*89c4ff92SAndroid Build Coastguard Worker     }
93*89c4ff92SAndroid Build Coastguard Worker 
GetGuid() const94*89c4ff92SAndroid Build Coastguard Worker     virtual LayerGuid GetGuid() const override
95*89c4ff92SAndroid Build Coastguard Worker     {
96*89c4ff92SAndroid Build Coastguard Worker         return m_FirstLayer->GetGuid();
97*89c4ff92SAndroid Build Coastguard Worker     }
98*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(armnn::IStrategy & strategy) const99*89c4ff92SAndroid Build Coastguard Worker     virtual void ExecuteStrategy(armnn::IStrategy& strategy) const override
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         // Do not expect this function to be used so not providing an implementation
102*89c4ff92SAndroid Build Coastguard Worker         // if an implementation is required and the chain contains more than two operators
103*89c4ff92SAndroid Build Coastguard Worker         // would have to provide a way to record the intermediate layers so they could be
104*89c4ff92SAndroid Build Coastguard Worker         // visited... the same applies to the BackendSelectionHint
105*89c4ff92SAndroid Build Coastguard Worker         // below.
106*89c4ff92SAndroid Build Coastguard Worker     }
107*89c4ff92SAndroid Build Coastguard Worker 
BackendSelectionHint(armnn::Optional<armnn::BackendId> backend)108*89c4ff92SAndroid Build Coastguard Worker     virtual void BackendSelectionHint(armnn::Optional<armnn::BackendId> backend) override
109*89c4ff92SAndroid Build Coastguard Worker     {
110*89c4ff92SAndroid Build Coastguard Worker         // Do not expect this function to be used so not providing an implementation
111*89c4ff92SAndroid Build Coastguard Worker     }
112*89c4ff92SAndroid Build Coastguard Worker 
GetType() const113*89c4ff92SAndroid Build Coastguard Worker     virtual armnn::LayerType GetType() const override
114*89c4ff92SAndroid Build Coastguard Worker     {
115*89c4ff92SAndroid Build Coastguard Worker         return m_FirstLayer->GetType();
116*89c4ff92SAndroid Build Coastguard Worker     }
117*89c4ff92SAndroid Build Coastguard Worker 
GetParameters() const118*89c4ff92SAndroid Build Coastguard Worker     virtual const armnn::BaseDescriptor& GetParameters() const override { return m_NullDescriptor; }
119*89c4ff92SAndroid Build Coastguard Worker 
SetBackendId(const armnn::BackendId & id)120*89c4ff92SAndroid Build Coastguard Worker     void SetBackendId(const armnn::BackendId& id) override {}
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker protected:
123*89c4ff92SAndroid Build Coastguard Worker     /// Retrieve the handles to the constant values stored by the layer.
124*89c4ff92SAndroid Build Coastguard Worker     /// @return A vector of the constant tensors stored by this layer.
GetConstantTensorsByRef()125*89c4ff92SAndroid Build Coastguard Worker     ConstantTensors GetConstantTensorsByRef() override { return {}; }
GetConstantTensorsByRef() const126*89c4ff92SAndroid Build Coastguard Worker     ImmutableConstantTensors GetConstantTensorsByRef() const override { return {}; }
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker private:
129*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* m_FirstLayer;
130*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* m_LastLayer;
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker     // to satisfy the GetParameters method need to hand back a NullDescriptor
133*89c4ff92SAndroid Build Coastguard Worker     armnn::NullDescriptor m_NullDescriptor;
134*89c4ff92SAndroid Build Coastguard Worker };
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
137