xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020-2021,2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 #include <armnn/backends/TensorHandle.hpp>
10 #include <armnn/utility/IgnoreUnused.hpp>
11 #include <armnn/utility/PolymorphicDowncast.hpp>
12 
13 namespace armnn
14 {
15 namespace optimizations
16 {
17 
18 static const std::set<armnn::LayerType> broadcastOps{ LayerType::Addition,       LayerType::Division,
19                                                       LayerType::Maximum,        LayerType::Minimum,
20                                                       LayerType::Multiplication, LayerType::Prelu,
21                                                       LayerType::Subtraction,    LayerType::ElementwiseBinary };
22 
23 class AddBroadcastReshapeLayerImpl
24 {
25 public:
26     /// Run for every ElementwiseBaseLayer. Add Broadcast reshape layer if the inputs shape are different.
Run(Graph & graph,Layer & layer) const27     void Run(Graph& graph, Layer& layer) const
28     {
29         if (std::find(broadcastOps.begin(), broadcastOps.end(), layer.GetType()) != broadcastOps.end())
30         {
31             layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet();
32             layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet();
33 
34             const TensorInfo& inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
35             const TensorInfo& inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
36 
37             if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
38             {
39                 return;
40             }
41 
42             unsigned int reshapeSlot = 1;
43             TensorInfo reshapeInfo   = inputInfo1;
44             TensorInfo inputInfo     = inputInfo0;
45 
46             if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
47             {
48                 reshapeSlot = 0;
49                 reshapeInfo = inputInfo0;
50                 inputInfo   = inputInfo1;
51             }
52 
53             uint32_t numDimensions = inputInfo.GetNumDimensions();
54 
55             std::vector<unsigned> reshapedDim;
56             for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
57             {
58                 reshapedDim.push_back(reshapeInfo.GetShape()[i]);
59             }
60 
61             std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
62             std::copy_backward(reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
63 
64             reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
65 
66             // If the parent layer is a Constant layer and it is only used once we can short circuit by just
67             // changing the tensor info rather than adding a reshape layer.
68             Layer& parentLayer = layer.GetInputSlot(reshapeSlot).GetConnectedOutputSlot()->GetOwningLayer();
69             if ((parentLayer.GetType() == armnn::LayerType::Constant) &&
70                 (parentLayer.GetOutputSlot(0).GetNumConnections() == 1))
71             {
72                 ConstantLayer& constantLayer = static_cast<ConstantLayer&>(parentLayer);
73 
74                 constantLayer.m_LayerOutput = std::make_unique<ScopedTensorHandle>(
75                     ConstTensor(reshapeInfo, constantLayer.m_LayerOutput.get()->GetConstTensor<void>()));
76                 constantLayer.GetOutputSlot().SetTensorInfo(reshapeInfo);
77             }
78             else
79             {
80                 const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
81                 const ReshapeDescriptor descriptor{ reshapeInfo.GetShape() };
82                 ReshapeLayer* reshapeLayer =
83                     graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot), descriptor, layerName.c_str());
84                 reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
85             }
86         }
87     }
88 
89 protected:
90     AddBroadcastReshapeLayerImpl()  = default;
91     ~AddBroadcastReshapeLayerImpl() = default;
92 };
93 
94 using AddBroadcastReshapeLayer = OptimizeForType<Layer, AddBroadcastReshapeLayerImpl>;
95 
96 }    // namespace optimizations
97 }    // namespace armnn
98